
Conversation

@klei22 (Collaborator) commented Oct 19, 2025

This pull request introduces scheduled per-batch mixing for multi-dataset training, allowing each batch to be composed of samples from multiple datasets according to configurable sampling probabilities. This enables more flexible and fine-grained control over dataset mixing during training, including smooth transitions between different proportions. Key changes include new demo scripts and config files, updates to argument parsing, and major modifications to the batching logic in train.py to support per-batch mixing and accurate tracking of training progress per dataset.

Per-batch mixing feature implementation:

  • Added the --dataset_mixing_per_batch argument to train_args.py, enabling per-batch mixing in multidataset mode.
  • Overhauled the batching logic in train.py to support scheduled per-batch mixing: each batch is now constructed from multiple datasets according to the current sampling probabilities, with careful handling of rounding, sampling methods, and tensor assembly (a sketch of this assembly follows this list).
  • Implemented tracking of the number of samples drawn from each dataset in the current batch (self.current_batch_dataset_counts), ensuring accurate token and epoch accounting for each dataset during training.
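The following is a rough, self-contained sketch of how such per-batch assembly can work. It is illustrative only: names such as mix_batch, the plain numpy token arrays, and the random-with-replacement sampling are assumptions, not the PR's actual code.

import numpy as np
import torch

def mix_batch(datasets, probs, batch_size, block_size):
    """Compose one batch from several 1-D token arrays.

    datasets: dict name -> np.ndarray of token ids
    probs:    dict name -> sampling probability (assumed to sum to 1)
    """
    # Proportional per-dataset counts with rounding; hand any remainder to the last dataset.
    names = list(datasets)
    counts = {n: int(round(probs[n] * batch_size)) for n in names}
    counts[names[-1]] += batch_size - sum(counts.values())

    xs, ys = [], []
    for name, count in counts.items():
        if count <= 0:
            continue
        data = datasets[name]
        # Random-with-replacement start indices; the PR also supports other sampling methods.
        starts = np.random.randint(0, len(data) - block_size - 1, size=count)
        for s in starts:
            xs.append(torch.from_numpy(data[s:s + block_size].astype(np.int64)))
            ys.append(torch.from_numpy(data[s + 1:s + 1 + block_size].astype(np.int64)))

    # Interleave samples from different datasets within the batch.
    order = torch.randperm(len(xs))
    x, y = torch.stack(xs)[order], torch.stack(ys)[order]
    # counts plays the role the PR tracks in current_batch_dataset_counts.
    return x, y, counts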

Demo and configuration additions:

  • Added demos/multidataset_per_batch_mixing.sh to demonstrate scheduled per-batch mixing with two datasets and a linear transition of sampling probabilities (see the schedule sketch after this list).
  • Added explorations/multidataset_per_batch_mixing.yaml as a config file for quick exploration runs using the new per-batch mixing feature.
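For the linear transition mentioned above, the schedule could look like the following hedged sketch; the function name and the 0.9/0.1 to 0.1/0.9 example proportions are illustrative, not taken from the demo script.

def linear_mix_schedule(step, total_steps, start_probs, end_probs):
    """Linearly interpolate per-dataset sampling probabilities over training."""
    t = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return {name: (1 - t) * start_probs[name] + t * end_probs[name]
            for name in start_probs}

# Halfway through a 10k-iteration run, a 0.9/0.1 mix has moved to 0.5/0.5:
print(linear_mix_schedule(5000, 10000, {"a": 0.9, "b": 0.1}, {"a": 0.1, "b": 0.9}))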

Type and import updates:

  • Updated type annotations and imports in train.py to support new features and improve code clarity.


Copilot AI left a comment


Pull Request Overview

This PR introduces scheduled per-batch mixing for multi-dataset training, enabling each batch to contain samples from multiple datasets according to configurable sampling probabilities that can transition over time.

Key changes:

  • Added --dataset_mixing_per_batch argument to enable per-batch dataset mixing
  • Implemented per-batch mixing logic with proportional sampling, rounding, and shuffling
  • Added tracking of per-dataset sample counts within each batch for accurate progress metrics

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Reviewed files:

  • train_args.py: Adds a new command-line argument to enable per-batch dataset mixing
  • train.py: Implements core per-batch mixing logic and tracking infrastructure
  • explorations/multidataset_per_batch_mixing.yaml: Provides a configuration template for per-batch mixing experiments
  • demos/multidataset_per_batch_mixing.sh: Demo script showing per-batch mixing with two datasets and a linear probability transition


def get_batch(self, split, target_dataset=None):
    dataset = None
    data = None
    self.current_batch_dataset_counts = None

Copilot AI Oct 19, 2025


Setting self.current_batch_dataset_counts to None at the start of get_batch() means it will be None for all non-per-batch-mixing code paths. This causes the condition if not batch_counts: at line 1869 to fail when batch_counts is None (it should check if batch_counts is None:), potentially leading to an AttributeError when trying to iterate over None in the else block at line 1860.
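A standalone illustration of the distinction (not project code):

for batch_counts in (None, {}, {"dataset_a": 8, "dataset_b": 4}):
    print(not batch_counts,          # True, True, False
          batch_counts is None)      # True, False, False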

# Compute epoch for logging:
if self.args.dataset_list:
    if not batch_counts:

Copilot AI Oct 19, 2025


This condition evaluates to True when batch_counts is None or an empty dict, but the code at line 1850 checks if batch_counts: which only executes when batch_counts is truthy (non-None and non-empty). These conditions should use explicit is None checks for clarity and correctness. Change to if batch_counts is None: to properly handle the case when per-batch mixing is not used.

Suggested change:
- if not batch_counts:
+ if batch_counts is None:

Comment on lines +776 to +792
elif self.args.sampling_method == "sequential":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)
elif self.args.sampling_method == "without_replacement":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_perm[dataset_name] = np.random.permutation(available)
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)

Copilot AI Oct 19, 2025


This function duplicates logic that likely exists elsewhere in the codebase for sampling indices. The sequential and without_replacement branches have identical implementations except for line 786 (permutation regeneration). Consider extracting and reusing existing sampling logic or consolidating the duplicated code within this function.

Suggested change (consolidate the two branches quoted above into one):

elif self.args.sampling_method in ("sequential", "without_replacement"):
    if self.dataset_ptr[dataset_name] + count > available:
        if self.args.sampling_method == "without_replacement":
            self.dataset_perm[dataset_name] = np.random.permutation(available)
        self.dataset_ptr[dataset_name] = 0
    start = self.dataset_ptr[dataset_name]
    end = start + count
    indices = self.dataset_perm[dataset_name][start:end]
    self.dataset_ptr[dataset_name] = end
    return torch.tensor(indices)

Comment on lines +796 to +805
def build_tensors(data_array, indices):
    x_local = torch.stack([
        torch.from_numpy(data_array[i:i+self.args.block_size].astype(np.int64))
        for i in indices
    ])
    y_local = torch.stack([
        torch.from_numpy(data_array[i+1:i+1+self.args.block_size].astype(np.int64))
        for i in indices
    ])
    return x_local, y_local

Copilot AI Oct 19, 2025


This function duplicates tensor construction logic that likely exists elsewhere for batch creation. Consider extracting this into a reusable helper method to avoid code duplication and ensure consistency across different batching paths.
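One possible consolidation, sketched with illustrative names (block_size is passed explicitly rather than read from self.args), would be a single helper shared by the single-dataset and per-batch mixing paths:

import numpy as np
import torch

def build_xy(data_array, indices, block_size):
    """Stack (input, target) windows for the given start indices; usable by
    both the single-dataset path and the per-batch mixing path."""
    x = torch.stack([torch.from_numpy(data_array[i:i + block_size].astype(np.int64))
                     for i in indices])
    y = torch.stack([torch.from_numpy(data_array[i + 1:i + 1 + block_size].astype(np.int64))
                     for i in indices])
    return x, y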
