Allow multidataset and multitokenization training to co-occur within a single iteration mixture #658
base: master
Conversation
Pull Request Overview
This PR introduces scheduled per-batch mixing for multi-dataset training, enabling each batch to contain samples from multiple datasets according to configurable sampling probabilities that can transition over time.
Key changes:
- Added `--dataset_mixing_per_batch` argument to enable per-batch dataset mixing
- Implemented per-batch mixing logic with proportional sampling, rounding, and shuffling (sketched after this list)
- Added tracking of per-dataset sample counts within each batch for accurate progress metrics
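For readers unfamiliar with the approach, here is a minimal sketch of proportional per-batch allocation with largest-remainder rounding; `allocate_batch_counts` and the dataset names are illustrative stand-ins, not the helpers actually used in `train.py`:

```python
import numpy as np

def allocate_batch_counts(batch_size, sampling_probs):
    """Split one batch across datasets proportionally to sampling_probs.

    sampling_probs maps dataset name -> probability (assumed to sum to 1.0).
    Returns dataset name -> number of samples to draw for this batch.
    """
    names = list(sampling_probs)
    raw = np.array([sampling_probs[n] * batch_size for n in names])
    counts = np.floor(raw).astype(int)
    # Hand the leftover slots to the datasets with the largest fractional parts.
    leftover = batch_size - counts.sum()
    for i in np.argsort(raw - np.floor(raw))[::-1][:leftover]:
        counts[i] += 1
    return dict(zip(names, counts.tolist()))

# Example: a 12-sample batch split 70/30 between two (hypothetical) datasets.
print(allocate_batch_counts(12, {"dataset_a": 0.7, "dataset_b": 0.3}))
# {'dataset_a': 8, 'dataset_b': 4}
```

The per-dataset samples would then be concatenated and shuffled so that samples from different datasets are interleaved within the batch.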
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| train_args.py | Adds new command-line argument to enable per-batch dataset mixing |
| train.py | Implements core per-batch mixing logic and tracking infrastructure |
| explorations/multidataset_per_batch_mixing.yaml | Provides configuration template for per-batch mixing experiments |
| demos/multidataset_per_batch_mixing.sh | Demo script showing per-batch mixing with two datasets and linear probability transition |
```python
def get_batch(self, split, target_dataset=None):
    dataset = None
    data = None
    self.current_batch_dataset_counts = None
```
Copilot AI · Oct 19, 2025
Setting `self.current_batch_dataset_counts` to `None` at the start of `get_batch()` means it will be `None` for all non-per-batch-mixing code paths. This causes the condition `if not batch_counts:` at line 1869 to fail when `batch_counts` is `None` (it should check `if batch_counts is None:`), potentially leading to an `AttributeError` when trying to iterate over `None` in the `else` block at line 1860.
```python
# Compute epoch for logging:
if self.args.dataset_list:
    if self.args.dataset_list:
        if not batch_counts:
```
Copilot AI · Oct 19, 2025
This condition evaluates to `True` when `batch_counts` is `None` or an empty dict, but the code at line 1850 checks `if batch_counts:`, which only executes when `batch_counts` is truthy (non-`None` and non-empty). These conditions should use explicit `is None` checks for clarity and correctness. Change to `if batch_counts is None:` to properly handle the case when per-batch mixing is not used.
Suggested change:

```diff
-if not batch_counts:
+if batch_counts is None:
```
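The distinction the suggestion is driving at is that `not x` treats `None` and an empty dict the same way; a quick illustration:

```python
batch_counts = None           # per-batch mixing not in use
print(not batch_counts)       # True
print(batch_counts is None)   # True

batch_counts = {}             # mixing in use, but no counts recorded yet
print(not batch_counts)       # True  -- indistinguishable from the None case
print(batch_counts is None)   # False -- correctly distinguishes the two cases
```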
```python
elif self.args.sampling_method == "sequential":
    if self.dataset_ptr[dataset_name] + count > available:
        self.dataset_ptr[dataset_name] = 0
    start = self.dataset_ptr[dataset_name]
    end = start + count
    indices = self.dataset_perm[dataset_name][start:end]
    self.dataset_ptr[dataset_name] = end
    return torch.tensor(indices)
elif self.args.sampling_method == "without_replacement":
    if self.dataset_ptr[dataset_name] + count > available:
        self.dataset_perm[dataset_name] = np.random.permutation(available)
        self.dataset_ptr[dataset_name] = 0
    start = self.dataset_ptr[dataset_name]
    end = start + count
    indices = self.dataset_perm[dataset_name][start:end]
    self.dataset_ptr[dataset_name] = end
    return torch.tensor(indices)
```
Copilot AI · Oct 19, 2025
This function duplicates logic that likely exists elsewhere in the codebase for sampling indices. The `sequential` and `without_replacement` branches have identical implementations except for line 786 (permutation regeneration). Consider extracting and reusing existing sampling logic or consolidating the duplicated code within this function.
Suggested change:

```diff
-elif self.args.sampling_method == "sequential":
-    if self.dataset_ptr[dataset_name] + count > available:
-        self.dataset_ptr[dataset_name] = 0
-    start = self.dataset_ptr[dataset_name]
-    end = start + count
-    indices = self.dataset_perm[dataset_name][start:end]
-    self.dataset_ptr[dataset_name] = end
-    return torch.tensor(indices)
-elif self.args.sampling_method == "without_replacement":
-    if self.dataset_ptr[dataset_name] + count > available:
-        self.dataset_perm[dataset_name] = np.random.permutation(available)
-        self.dataset_ptr[dataset_name] = 0
-    start = self.dataset_ptr[dataset_name]
-    end = start + count
-    indices = self.dataset_perm[dataset_name][start:end]
-    self.dataset_ptr[dataset_name] = end
-    return torch.tensor(indices)
+elif self.args.sampling_method in ("sequential", "without_replacement"):
+    if self.dataset_ptr[dataset_name] + count > available:
+        if self.args.sampling_method == "without_replacement":
+            self.dataset_perm[dataset_name] = np.random.permutation(available)
+        self.dataset_ptr[dataset_name] = 0
+    start = self.dataset_ptr[dataset_name]
+    end = start + count
+    indices = self.dataset_perm[dataset_name][start:end]
+    self.dataset_ptr[dataset_name] = end
+    return torch.tensor(indices)
```
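As an alternative to merging the two branches, the shared cursor logic could also be pulled into a small method that both branches call; `_draw_from_cursor` below is a hypothetical name, not something that exists in the codebase:

```python
def _draw_from_cursor(self, dataset_name, count, available, reshuffle_on_wrap):
    """Return the next `count` sample indices for a dataset, advancing its cursor.

    When the cursor would run past the end of the data, optionally regenerate
    the permutation (the without_replacement behaviour) before wrapping to 0.
    """
    if self.dataset_ptr[dataset_name] + count > available:
        if reshuffle_on_wrap:
            self.dataset_perm[dataset_name] = np.random.permutation(available)
        self.dataset_ptr[dataset_name] = 0
    start = self.dataset_ptr[dataset_name]
    end = start + count
    indices = self.dataset_perm[dataset_name][start:end]
    self.dataset_ptr[dataset_name] = end
    return torch.tensor(indices)
```

The two `elif` branches then collapse to one-line calls with `reshuffle_on_wrap=False` and `reshuffle_on_wrap=True` respectively.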
```python
def build_tensors(data_array, indices):
    x_local = torch.stack([
        torch.from_numpy(data_array[i:i+self.args.block_size].astype(np.int64))
        for i in indices
    ])
    y_local = torch.stack([
        torch.from_numpy(data_array[i+1:i+1+self.args.block_size].astype(np.int64))
        for i in indices
    ])
    return x_local, y_local
```
Copilot AI · Oct 19, 2025
This function duplicates tensor construction logic that likely exists elsewhere for batch creation. Consider extracting this into a reusable helper method to avoid code duplication and ensure consistency across different batching paths.
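One way to make such a helper reusable across batching paths would be to pass the block size explicitly instead of reading it from `self.args`; this is an illustrative sketch, not code from the PR:

```python
import numpy as np
import torch

def stack_block_pairs(data_array, indices, block_size):
    """Build (x, y) next-token-prediction tensors from start offsets into a token array."""
    x = torch.stack([
        torch.from_numpy(data_array[i:i + block_size].astype(np.int64))
        for i in indices
    ])
    # y is x shifted right by one token, the usual language-modelling target.
    y = torch.stack([
        torch.from_numpy(data_array[i + 1:i + 1 + block_size].astype(np.int64))
        for i in indices
    ])
    return x, y
```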
This pull request introduces scheduled per-batch mixing for multi-dataset training, allowing each batch to be composed of samples from multiple datasets according to configurable sampling probabilities. This enables more flexible and fine-grained control over dataset mixing during training, including smooth transitions between different proportions. Key changes include new demo scripts and config files, updates to argument parsing, and major modifications to the batching logic in `train.py` to support per-batch mixing and accurate tracking of training progress per dataset.

Per-batch mixing feature implementation:
- Added the `--dataset_mixing_per_batch` argument to `train_args.py`, enabling per-batch mixing in multidataset mode.
- Updated the batching logic in `train.py` to support scheduled per-batch mixing: each batch is now constructed from multiple datasets according to current sampling probabilities, with careful handling of rounding, sampling methods, and tensor assembly.
- Added tracking of per-dataset sample counts within each batch (`self.current_batch_dataset_counts`), ensuring accurate token and epoch accounting for each dataset during training. [1] [2] [3] [4]

Demo and configuration additions:
- Added `demos/multidataset_per_batch_mixing.sh` to demonstrate scheduled per-batch mixing with two datasets and a linear transition of sampling probabilities.
- Added `explorations/multidataset_per_batch_mixing.yaml` as a config file for quick exploration runs using the new per-batch mixing feature.

Type and import updates:
- Updated imports and type annotations in `train.py` to support the new features and improve code clarity.
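The linear transition mentioned above amounts to interpolating the per-dataset sampling probabilities across iterations; a minimal sketch of such a schedule follows (the function name and example values are illustrative, not taken from the demo script):

```python
def linear_mixing_schedule(iter_num, max_iters, start_probs, end_probs):
    """Linearly interpolate per-dataset sampling probabilities over training."""
    t = min(max(iter_num / max_iters, 0.0), 1.0)
    return {
        name: (1.0 - t) * start_probs[name] + t * end_probs[name]
        for name in start_probs
    }

# Example: shift from a 90/10 mix to a 10/90 mix over 1000 iterations.
print(linear_mixing_schedule(
    500, 1000,
    start_probs={"dataset_a": 0.9, "dataset_b": 0.1},
    end_probs={"dataset_a": 0.1, "dataset_b": 0.9},
))
# -> {'dataset_a': 0.5, 'dataset_b': 0.5} (up to float rounding)
```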