Describe the bug
After finetuning Llama 3.2 1B Instruct on a simple identity dataset, the model's chat-completion endpoint produces garbage output.
Steps/Code to reproduce bug
My config is as follows:
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this recipe, please use the following command:
# torchrun --nproc-per-node=8 recipes/llm_finetune/finetune.py --config recipes/llm_finetune/llama3_2/llama3_2_1b_squad.yaml
# Adjust --nproc-per-node to the number of GPUs available on your host machine.
step_scheduler:
  global_batch_size: 64
  local_batch_size: 8
  ckpt_every_steps: 100
  val_every_steps: 10  # will run every x number of gradient steps
  max_steps: 25

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct

peft:
  _target_: nemo_automodel.components._peft.lora.PeftConfig
  match_all_linear: True
  dim: 8
  alpha: 32
  use_triton: True

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  dp_size: none
  dp_replicate_size: 1  # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 1
  cp_size: 1
  sequence_parallel: false

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

# dataset:
#   _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
#   dataset_name: rajpurkar/squad
#   split: train

dataset:
  _target_: nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset.ColumnMappedTextInstructionDataset
  column_mapping:
    # context: definition
    question: question
    answer: answer
  answer_only_loss_mask: False
  # start_of_turn_token: "<|assistant|>"
  path_or_dataset_id: /workspace/dataset-pipeline/data/nemo-dataset/train/train.jsonl

tokenizer:
  _target_: nemo_automodel._transformers.auto_tokenizer.NeMoAutoTokenizer.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
  trust_remote_code: true

packed_sequence:
  packed_sequence_size: 0

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset.ColumnMappedTextInstructionDataset
  path_or_dataset_id: /workspace/dataset-pipeline/data/nemo-dataset/val/val.jsonl
  column_mapping:
    # context: definition
    question: question
    answer: answer
  # answer_only_loss_mask: true
  # start_of_turn_token: "<|assistant|>"

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  # min_lr: 1.0e-5

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

mlflow:
  experiment_name: automodel-llm-llama3_2_1b_codegen-finetune
  run_name: ''
  tracking_uri: http://mlflow:5000
  artifact_location: s3://mlflow/nemo-automodel
  tags:
    task: codegen-finetune
    model_family: llama3.2
    model_size: 1b
    dataset: codegen
    framework: automodel

checkpoint:
  enabled: true
  checkpoint_dir: /workspace/nemo-stack/codegen/checkpoints/llama-3.2-1b-instruct
  model_save_format: safetensors
  save_consolidated: false
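Given the column_mapping above, each line of train.jsonl is assumed to look roughly like this (illustrative values, not the real data):

{"question": "Who are you?", "answer": "I am an assistant fine-tuned by the Example team."}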
I merge the adapters using the code below:
import os

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def merge_adapter_with_base_model(
    base_model_name: str,
    adapter_path: str,
    output_path: str,
    save: bool = True,
):
    """
    Merge LoRA adapter with base model and optionally save merged model.

    Args:
        base_model_name: HuggingFace model ID or path to base model
        adapter_path: Path to the LoRA adapter directory
        output_path: Path where merged model will be saved
        save: Whether to save the merged model to disk

    Returns:
        Tuple of (merged_model, tokenizer)
    """
    print(f"\n{'='*80}")
    print("🔧 MERGING LORA ADAPTER WITH BASE MODEL")
    print(f"{'='*80}")

    print(f"\n📦 Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )

    print(f"📝 Loading tokenizer: {base_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # tokenizer.padding_side = 'left'
    # if tokenizer.pad_token is None:
    #     tokenizer.pad_token = tokenizer.eos_token

    print(f"\n🔌 Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(model, adapter_path)

    print("⚙️ Merging adapter weights with base model...")
    merged_model = model.merge_and_unload()

    if save:
        print(f"\n💾 Saving merged model to: {output_path}")
        os.makedirs(output_path, exist_ok=True)
        merged_model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print("✅ Merged model saved successfully!")
    else:
        print("⏭️ Skipping save (--no-save specified)")

    return merged_model, tokenizer
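For reference, I call it roughly like this (the adapter sub-path below is a placeholder; the exact layout under checkpoint_dir depends on the checkpoint step):

# Hypothetical invocation: adapter_path should point at the LoRA adapter
# directory written under checkpoint_dir during training.
merged_model, tokenizer = merge_adapter_with_base_model(
    base_model_name="meta-llama/Llama-3.2-1B-Instruct",
    adapter_path="/workspace/nemo-stack/codegen/checkpoints/llama-3.2-1b-instruct/<step_dir>",  # placeholder
    output_path="/workspace/nemo-stack/codegen/llama-3.2-1b-instruct-merged",
)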
Additional context
I think the issue is somehow related to the tokenizer.
I expect nemo-automodel to use the default chat template that ships with Llama-3.2-1B-Instruct:
tokenizer:
  _target_: nemo_automodel._transformers.auto_tokenizer.NeMoAutoTokenizer.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
  trust_remote_code: true
I use the same chat template when serving with vLLM.
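As a sanity check on the tokenizer theory, this is the kind of snippet I use to see what the stock Llama 3.2 chat template renders, so I can compare it against what the training pipeline and vLLM actually feed the model (plain transformers API, nothing NeMo-specific):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

messages = [{"role": "user", "content": "Who are you?"}]

# Render the default chat template as text; if the prompt format at serving
# time differs from the one seen during fine-tuning, outputs degrade.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # expect <|begin_of_text|><|start_header_id|>user<|end_header_id|> ... <|eot_id|>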