Add model merge #721
base: master
Conversation
Pull request overview
This pull request introduces a utility script for merging two nanoGPT model checkpoints with flexible normalization options. The script supports L2-normalized merging (the default behavior), simple averaging without normalization, and an option to skip final normalization for embedding and language model head weights.
Changes:
- Added `model_merge.py`, a utility script that merges two checkpoint files with configurable L2 normalization and averaging strategies
- Added `demos/model_merge_demo.sh`, demonstrating the three main usage patterns of the merge script
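For orientation, here is a minimal sketch of the kind of merge the script performs. The helper and parameter names (`_unit`, `merge_state_dicts`, `skip_final_norm`) and the exact order of normalization steps are assumptions for illustration, not the PR's actual interface:

```python
# Illustrative sketch only; names and normalization order are assumptions,
# not the PR's actual implementation.
import torch

def _unit(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Scale slices along `dim` to unit L2 norm; pass scalars through unchanged.
    if t.ndim == 0:
        return t
    return t / t.norm(p=2, dim=0 if t.ndim == 1 else dim, keepdim=True).clamp_min(1e-12)

def merge_state_dicts(sd_a, sd_b, l2: bool = True, skip_final_norm=()):
    """Average two state dicts, optionally L2-normalizing before and after averaging."""
    merged = {}
    for key, a in sd_a.items():
        b = sd_b[key]
        if l2:
            a, b = _unit(a), _unit(b)
        out = (a + b) / 2.0
        if l2 and key not in skip_final_norm:
            out = _unit(out)  # the skippable "final normalization"
        merged[key] = out
    return merged
```

With `l2=False` this degenerates to plain averaging, which corresponds to the script's no-normalization mode.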
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| model_merge.py | New utility for merging two nanoGPT checkpoints with L2 normalization, simple averaging, and selective layer normalization skipping |
| demos/model_merge_demo.sh | Demo script illustrating typical merge operations with different normalization options |
    if isinstance(checkpoint_a, dict):
        checkpoint_a["iter_num"] = 0
        checkpoint_a["best_val_loss"] = 1e9
        checkpoint_a["best_iter"] = 0
Copilot AI · Jan 15, 2026
The code resets best_iter to 0 but doesn't reset best_tokens, which is also tracked in checkpoints according to train.py. This creates an inconsistency: best_tokens from checkpoint_a would be preserved while best_iter and best_val_loss are reset. Either add checkpoint_a['best_tokens'] = 0 after line 142, or, if this field is not guaranteed to exist, use checkpoint_a.pop('best_tokens', None) after line 138 (a sketch of the pop-based variant follows the suggested change below).
| checkpoint_a["best_iter"] = 0 | |
| checkpoint_a["best_iter"] = 0 | |
| checkpoint_a["best_tokens"] = 0 |
| checkpoint_a.pop("optimizer", None) | ||
| checkpoint_a.pop("scheduler", None) |
Copilot AI · Jan 15, 2026
The code mutates checkpoint_a directly, but lines 132-135 may reassign checkpoint_a to a new object. If checkpoint_a is reassigned to be just the state_dict (line 135), these .pop() calls would then be operating on the wrong object, since a state_dict doesn't have 'optimizer' or 'scheduler' keys. Consider handling this more explicitly by checking isinstance(checkpoint_a, dict) before attempting to pop these keys, similar to the check on line 139.
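A guard along the lines the comment suggests might look like this (a sketch that assumes checkpoint_a only carries training state when it is the full checkpoint dict; not code from the PR):

```python
# Only strip training-state entries when we actually have the full checkpoint dict,
# mirroring the isinstance check used for the metadata reset.
if isinstance(checkpoint_a, dict):
    checkpoint_a.pop("optimizer", None)
    checkpoint_a.pop("scheduler", None)
```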
    def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
        if tensor.ndim == 0:
            return tensor
        if tensor.ndim == 1:
            dim = 0
Copilot AI · Jan 15, 2026
The l2_normalize function has special handling for scalar (0-dim) and vector (1-dim) tensors but lacks documentation explaining this behavior. Add a docstring documenting that: (1) scalar tensors are returned unchanged, (2) 1-dim tensors normalize along dim=0, and (3) higher-dim tensors use the provided dim parameter (default=-1).
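One way the suggested docstring could read, assuming a module-level `L2_NORM_DIM = -1` and that the body continues by dividing by the L2 norm along `dim` (only the dim handling appears in the hunk above; the rest is an assumption, not the PR's code):

```python
import torch

L2_NORM_DIM = -1  # assumed module-level default (the review notes default=-1)

def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
    """L2-normalize `tensor` along `dim`.

    - 0-dim (scalar) tensors are returned unchanged.
    - 1-dim tensors are always normalized along dim=0, regardless of `dim`.
    - Higher-dim tensors are normalized along the provided `dim` (default -1).
    """
    if tensor.ndim == 0:
        return tensor
    if tensor.ndim == 1:
        dim = 0
    # Body continuation is an assumption; the hunk above shows only the dim handling.
    return tensor / tensor.norm(p=2, dim=dim, keepdim=True).clamp_min(1e-12)
```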
This pull request introduces a new utility script for merging two nanoGPT model checkpoints with flexible options for normalization and averaging. The main addition is the `model_merge.py` script, which supports L2-normalized merging, skipping normalization for specific layers, and a simple averaging mode. A demo shell script is also provided to illustrate usage.

New model merging functionality:
- `model_merge.py`, a utility script for merging two nanoGPT checkpoints with options for L2 normalization, skipping final normalization for `wte`/`lm_head` weights, and simple averaging without normalization. The script handles key mismatches, validates shapes, and preserves metadata.

Demo and usage examples:
- `demos/model_merge_demo.sh`, a shell script demonstrating typical usage patterns for `model_merge.py`, including an L2-normalized merge, skipping final normalization for specific layers, and simple averaging.
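The description mentions key-mismatch and shape handling; a minimal sketch of what such validation could look like is shown below (the function name and error behavior are assumptions, not the PR's actual code):

```python
def validate_matching_state_dicts(sd_a: dict, sd_b: dict) -> None:
    # Keys present in only one of the two checkpoints.
    mismatched = set(sd_a) ^ set(sd_b)
    if mismatched:
        raise KeyError(f"checkpoints do not share these keys: {sorted(mismatched)}")
    # Tensors must agree in shape before they can be averaged.
    for key in sd_a:
        if sd_a[key].shape != sd_b[key].shape:
            raise ValueError(
                f"shape mismatch for {key}: {tuple(sd_a[key].shape)} vs {tuple(sd_b[key].shape)}"
            )
```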