
Fine-tuned version of llama-3.2-1b throws garbage results #1099

@mohittalele

Description

Describe the bug

After fine-tuning Llama 3.2 with a simple identity dataset, the model's chat-completion endpoint returns garbage results.

Steps/Code to reproduce bug

My config is as follows:


# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# To run this recipe, please use the following command:
# torchrun --nproc-per-node=8 recipes/llm_finetune/finetune.py --config recipes/llm_finetune/llama3_2/llama3_2_1b_squad.yaml
# Adjust --nproc-per-node to the number of GPUs available on your host machine.


step_scheduler:
  global_batch_size: 64
  local_batch_size: 8
  ckpt_every_steps: 100
  val_every_steps: 10  # will run every x number of gradient steps
  max_steps: 25

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct



peft:
  _target_: nemo_automodel.components._peft.lora.PeftConfig
  match_all_linear: True
  dim: 8
  alpha: 32
  use_triton: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  dp_replicate_size: 1 # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

# dataset:
#   _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
#   dataset_name: rajpurkar/squad
#   split: train

dataset:
  _target_: nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset.ColumnMappedTextInstructionDataset
  column_mapping:
    # context: definition
    question: question
    answer: answer
  answer_only_loss_mask: False
  # start_of_turn_token: "<|assistant|>"
  path_or_dataset_id: /workspace/dataset-pipeline/data/nemo-dataset/train/train.jsonl
  tokenizer:
    _target_: nemo_automodel._transformers.auto_tokenizer.NeMoAutoTokenizer.from_pretrained
    pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
    trust_remote_code: true

packed_sequence:
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset.ColumnMappedTextInstructionDataset
  path_or_dataset_id: /workspace/dataset-pipeline/data/nemo-dataset/val/val.jsonl
  column_mapping:
    # context: definition
    question: question
    answer: answer
  # answer_only_loss_mask: true
  # start_of_turn_token: "<|assistant|>"

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  # min_lr: 1.0e-5

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

mlflow:
  experiment_name: automodel-llm-llama3_2_1b_codegen-finetune
  run_name: ''
  tracking_uri: http://mlflow:5000
  artifact_location: s3://mlflow/nemo-automodel
  tags:
    task: codegen-finetune
    model_family: llama3.2
    model_size: 1b
    dataset: codegen
    framework: automodel

checkpoint:
  enabled: true
  checkpoint_dir: /workspace/nemo-stack/codegen/checkpoints/llama-3.2-1b-instruct
  model_save_format: safetensors
  save_consolidated: false
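
For context, the file pointed to by path_or_dataset_id is a JSONL file whose records carry the question and answer fields named in column_mapping. The records below are only a sketch of the assumed structure (the actual training data is not reproduced here):

import json

# Hypothetical identity-style records; only the question/answer structure is
# what matters for ColumnMappedTextInstructionDataset with the mapping above.
records = [
    {"question": "Who are you?", "answer": "I am a fine-tuned Llama 3.2 assistant."},
    {"question": "What can you do?", "answer": "I help with code generation tasks."},
]

with open("train.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")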

I merge the adapters using the code below:

import os

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def merge_adapter_with_base_model(
    base_model_name: str,
    adapter_path: str,
    output_path: str,
    save: bool = True
):
    """
    Merge LoRA adapter with base model and optionally save merged model.
    
    Args:
        base_model_name: HuggingFace model ID or path to base model
        adapter_path: Path to the LoRA adapter directory
        output_path: Path where merged model will be saved
        save: Whether to save the merged model to disk
    
    Returns:
        Tuple of (merged_model, tokenizer)
    """
    print(f"\n{'='*80}")
    print("🔧 MERGING LORA ADAPTER WITH BASE MODEL")
    print(f"{'='*80}")
    
    print(f"\n📦 Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    
    print(f"📝 Loading tokenizer: {base_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # tokenizer.padding_side = 'left'
    # if tokenizer.pad_token is None:
    #     tokenizer.pad_token = tokenizer.eos_token
    
    print(f"\n🔌 Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(model, adapter_path)
    
    print("⚙️  Merging adapter weights with base model...")
    merged_model = model.merge_and_unload()
    
    if save:
        print(f"\n💾 Saving merged model to: {output_path}")
        os.makedirs(output_path, exist_ok=True)
        merged_model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print(f"✅ Merged model saved successfully!")
    else:
        print("⏭️  Skipping save (--no-save specified)")
    
    return merged_model, tokenizer
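
A typical invocation looks roughly like this; the adapter and output paths below are placeholders, not the exact ones from my run:

# Hypothetical paths for illustration; in practice adapter_path points at the
# PEFT checkpoint directory written under checkpoint_dir from the config above.
merged_model, tokenizer = merge_adapter_with_base_model(
    base_model_name="meta-llama/Llama-3.2-1B-Instruct",
    adapter_path="/workspace/nemo-stack/codegen/checkpoints/llama-3.2-1b-instruct/step_25",  # placeholder
    output_path="/workspace/nemo-stack/codegen/merged/llama-3.2-1b-instruct",  # placeholder
)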

Additional context

I think the issue is somehow related to the tokenizer.

I expect that nemo-automodel will use the default chat template available for Llama-3.2-1B:

  tokenizer:
    _target_: nemo_automodel._transformers.auto_tokenizer.NeMoAutoTokenizer.from_pretrained
    pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
    trust_remote_code: true

I use the same chat_template when serving with vLLM.
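
As a quick sanity check on the template side, one can render a single turn with the tokenizer saved next to the merged model and compare it to what vLLM sends. A minimal sketch (the path and prompt are made up):

from transformers import AutoTokenizer

# Load the tokenizer saved alongside the merged model (hypothetical path) and
# render one user turn with its built-in chat template.
tok = AutoTokenizer.from_pretrained("/workspace/nemo-stack/codegen/merged/llama-3.2-1b-instruct")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Who are you?"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # should contain the Llama 3 style header/eot markers used at training time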
