
Conversation

klei22 (Collaborator) commented Jan 15, 2026

This pull request introduces a new utility script for merging two nanoGPT model checkpoints with flexible options for normalization and averaging. The main addition is the model_merge.py script, which supports L2-normalized merging, skipping normalization for specific layers, and a simple averaging mode. A demo shell script is also provided to illustrate usage.

New model merging functionality:

  • Added model_merge.py, a utility script for merging two nanoGPT checkpoints with options for L2 normalization, skipping final normalization for wte/lm_head weights, and simple averaging without normalization. The script handles key mismatches, shape validation, and preserves metadata.
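
To make the merge strategy concrete, here is a minimal sketch of the core logic. The names merge_state_dicts and skip_final_norm_keys are illustrative rather than taken from the script, and the exact normalize/average/renormalize ordering is an assumption; the real model_merge.py additionally handles key mismatches and checkpoint metadata.

import torch
import torch.nn.functional as F

def l2_normalize(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Scalars pass through unchanged; vectors normalize along dim 0.
    if t.ndim == 0:
        return t
    if t.ndim == 1:
        dim = 0
    return F.normalize(t, p=2, dim=dim)

def merge_state_dicts(sd_a, sd_b, simple_average=False, skip_final_norm_keys=()):
    merged = {}
    for key, a in sd_a.items():
        b = sd_b[key]
        assert a.shape == b.shape, f"shape mismatch for {key}"
        if simple_average:
            # Plain averaging mode: no normalization at all.
            merged[key] = (a + b) / 2
        else:
            # L2-normalize each source tensor before averaging, then
            # renormalize the result, unless the key (e.g. wte/lm_head)
            # opts out of the final normalization.
            avg = (l2_normalize(a) + l2_normalize(b)) / 2
            merged[key] = avg if key in skip_final_norm_keys else l2_normalize(avg)
    return merged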

Demo and usage examples:

  • Added demos/model_merge_demo.sh, a shell script demonstrating typical usage patterns for model_merge.py, including L2-normalized merge, skipping final normalization for specific layers, and simple averaging.

Copilot AI left a comment

Pull request overview

This pull request introduces a utility script for merging two nanoGPT model checkpoints with flexible normalization options. The script supports L2-normalized merging (the default behavior), simple averaging without normalization, and an option to skip final normalization for embedding and language model head weights.

Changes:

  • Added model_merge.py utility script that merges two checkpoint files with configurable L2 normalization and averaging strategies
  • Added demos/model_merge_demo.sh demonstrating the three main usage patterns of the merge script

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

  • model_merge.py: New utility for merging two nanoGPT checkpoints with L2 normalization, simple averaging, and selective layer normalization skipping
  • demos/model_merge_demo.sh: Demo script illustrating typical merge operations with different normalization options


if isinstance(checkpoint_a, dict):
    checkpoint_a["iter_num"] = 0
    checkpoint_a["best_val_loss"] = 1e9
    checkpoint_a["best_iter"] = 0
Copilot AI Jan 15, 2026

The code resets best_iter to 0 but does not reset best_tokens, which train.py also tracks in checkpoints. This creates an inconsistency: best_tokens from checkpoint_a is preserved while best_iter and best_val_loss are reset. Either add checkpoint_a['best_tokens'] = 0 after line 142, or, if this field is not guaranteed to exist, use checkpoint_a.pop('best_tokens', None) after line 138.

Suggested change:

  checkpoint_a["best_iter"] = 0
+ checkpoint_a["best_tokens"] = 0

Comment on lines +137 to +138
checkpoint_a.pop("optimizer", None)
checkpoint_a.pop("scheduler", None)
Copilot AI Jan 15, 2026

The code pops keys from checkpoint_a directly, but lines 132-135 may reassign checkpoint_a to be just the state_dict (line 135). In that case the .pop() calls operate on the state_dict, which has no 'optimizer' or 'scheduler' keys, so the cleanup silently does nothing. Consider handling this more explicitly by checking isinstance(checkpoint_a, dict) before attempting to pop these keys, similar to the check on line 139.
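
One explicit way to guard the pops, sketched from the diff context (the "model" key follows the usual nanoGPT checkpoint layout and is an assumption here, not visible in the hunk):

if isinstance(checkpoint_a, dict) and "model" in checkpoint_a:
    # Only a full training checkpoint carries optimizer/scheduler state;
    # a bare state_dict has nothing to strip.
    checkpoint_a.pop("optimizer", None)
    checkpoint_a.pop("scheduler", None)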

Comment on lines +60 to +64
def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
    if tensor.ndim == 0:
        return tensor
    if tensor.ndim == 1:
        dim = 0
Copilot AI Jan 15, 2026

The l2_normalize function has special handling for scalar (0-dim) and vector (1-dim) tensors but lacks documentation explaining this behavior. Add a docstring documenting that: (1) scalar tensors are returned unchanged, (2) 1-dim tensors normalize along dim=0, and (3) higher-dim tensors use the provided dim parameter (default=-1).
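
A docstring along those lines could look like this; the function body below the hunk is not shown, so the final F.normalize call is a stand-in for whatever the script actually does:

import torch
import torch.nn.functional as F

L2_NORM_DIM = -1  # default per the comment above; the script's actual constant may differ

def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
    """L2-normalize a tensor with rank-dependent handling.

    - 0-dim (scalar) tensors are returned unchanged.
    - 1-dim tensors are normalized along dim=0, ignoring the dim argument.
    - Higher-dim tensors are normalized along dim (default -1, the last axis).
    """
    if tensor.ndim == 0:
        return tensor
    if tensor.ndim == 1:
        dim = 0
    return F.normalize(tensor, p=2, dim=dim)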
