Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions nemo_automodel/components/datasets/llm/retrieval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _transform_func(examples, num_neg_docs, corpus_dict):
Same as _format_process_data in RetrievalMultiModalDatasetLoader.

Args:
examples: Batch of examples with question, corpus_id, pos_doc, neg_doc
examples: Batch of examples with question, corpus_id, pos_doc, neg_doc, and optional epoch for cycling through positive documents
num_neg_docs: Number of negative documents to use
corpus_dict: Dictionary mapping corpus_id to corpus objects
"""
Expand All @@ -180,6 +180,7 @@ def _transform_func(examples, num_neg_docs, corpus_dict):
corpus_ids = examples["corpus_id"]
batch_positives = examples["pos_doc"]
batch_negatives = examples["neg_doc"]
epoch = examples.get("epoch", 0) # Get epoch from examples if present, else default to 0

# Check if we have enough negatives
if num_neg_docs > len(batch_negatives[0]):
Expand All @@ -190,10 +191,10 @@ def _transform_func(examples, num_neg_docs, corpus_dict):
for i_example in range(len(questions)):
cur_pos_neg_doc = []

# Get one positive doc (take first one)
# Get one positive doc (cycle through positives based on epoch)
positives = batch_positives[i_example]
if isinstance(positives, list) and len(positives) > 0:
cur_pos_neg_doc.append(positives[0])
cur_pos_neg_doc.append(positives[epoch % len(positives)])
else:
cur_pos_neg_doc.append(positives)

Expand Down Expand Up @@ -246,10 +247,12 @@ def _transform_func(examples, num_neg_docs, corpus_dict):
return result


def _create_transform_func(num_neg_docs, corpus_dict, epoch=0):
    """Create transform function with specified number of negative documents.

    Args:
        num_neg_docs: Number of negative documents forwarded to ``_transform_func``.
        corpus_dict: Dictionary mapping corpus_id to corpus objects.
        epoch: Epoch number made available to ``_transform_func`` under the
            ``"epoch"`` key of the batch (defaults to 0).

    Returns:
        A callable that takes a batch of examples and returns the result of
        ``_transform_func`` for that batch.
    """

    def transform(examples):
        # Inject the epoch through a shallow copy so the caller's batch
        # is not mutated as a side effect of applying the transform.
        batch = dict(examples)
        batch["epoch"] = epoch
        return _transform_func(batch, num_neg_docs=num_neg_docs, corpus_dict=corpus_dict)

    return transform
Expand Down Expand Up @@ -313,6 +316,10 @@ def make_retrieval_dataset(
negative_size = train_n_passages - 1
dataset.set_transform(_create_transform_func(negative_size, corpus_dict))

# Store metadata for updating transform later
dataset.corpus_dict = corpus_dict
dataset.num_neg_docs = negative_size

elif data_type == "eval":
# Set transform for evaluation
dataset.set_transform(_create_transform_func(eval_negative_size, corpus_dict))
Expand All @@ -325,6 +332,22 @@ def make_retrieval_dataset(
return dataset


def update_dataset_epoch(dataset, epoch):
    """
    Update the dataset transform to use the specified epoch for positive document selection.

    The transform is rebuilt from the ``corpus_dict`` and ``num_neg_docs``
    attributes stored on the dataset. If either attribute is missing the
    dataset is left untouched.

    Args:
        dataset: The HuggingFace dataset
        epoch: The new epoch number

    Returns:
        True if the transform was replaced, False if the dataset carries no
        stored metadata and was therefore left unchanged.
    """
    # Missing metadata is expected for eval datasets or when
    # make_retrieval_dataset wasn't used; report it instead of failing.
    if not (hasattr(dataset, "corpus_dict") and hasattr(dataset, "num_neg_docs")):
        return False

    dataset.set_transform(_create_transform_func(dataset.num_neg_docs, dataset.corpus_dict, epoch=epoch))
    return True


if __name__ == "__main__":
import argparse

Expand Down
6 changes: 6 additions & 0 deletions nemo_automodel/recipes/biencoder/train_biencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from nemo_automodel.components.training.step_scheduler import StepScheduler
from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm
from nemo_automodel.recipes.base_recipe import BaseRecipe
from nemo_automodel.components.datasets.llm.retrieval_dataset import update_dataset_epoch

if TYPE_CHECKING:
from nemo_automodel.components.distributed.init_utils import DistInfo
Expand Down Expand Up @@ -511,6 +512,11 @@ def run_train_validation_loop(self):

for epoch in self.step_scheduler.epochs:
self.step_scheduler.set_epoch(epoch)

# Update dataset epoch for positive document cycling
if hasattr(self.dataloader, "dataset"):
update_dataset_epoch(self.dataloader.dataset, epoch)

# The step scheduler yields a list of batches for gradient accumulation
for batches in self.step_scheduler:
train_log_data = self._run_train_optim_step(batches, 1.0)
Expand Down
65 changes: 65 additions & 0 deletions tests/unit_tests/datasets/llm/test_retrieval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,68 @@ def test_make_retrieval_dataset_invalid_type(tmp_path, monkeypatch):
rd.make_retrieval_dataset(str(train_file), data_type="invalid")


def test_transform_func_epoch_cycling():
    """Check that the positive document is chosen by ``epoch % len(positives)``.

    First exercises ``rd._transform_func`` directly with an explicit
    ``"epoch"`` entry in the batch, then checks that
    ``rd.update_dataset_epoch`` changes the positive returned by a dataset
    carrying ``corpus_dict`` / ``num_neg_docs`` metadata.
    """
    corpus_dict = {
        "corpusA": DummyCorpus(
            {
                "p1": {"text": "pos1", "image": "", "nr_ocr": ""},
                "p2": {"text": "pos2", "image": "", "nr_ocr": ""},
                "p3": {"text": "pos3", "image": "", "nr_ocr": ""},
                "n1": {"text": "neg1", "image": "", "nr_ocr": ""},
            }
        )
    }

    # A single-question batch with three candidate positives
    examples = {
        "question": ["Q"],
        "corpus_id": ["corpusA"],
        "pos_doc": [[{"id": "p1"}, {"id": "p2"}, {"id": "p3"}]],
        "neg_doc": [[{"id": "n1"}]],
    }

    # Epochs 0..2 walk through p1..p3; epoch 3 wraps back to p1
    expected_by_epoch = ["pos1", "pos2", "pos3", "pos1"]
    for current_epoch, expected_text in enumerate(expected_by_epoch):
        examples["epoch"] = current_epoch
        output = rd._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict)
        assert output["doc_text"][0][0] == expected_text

    # Test update_dataset_epoch helper on a dummy dataset with metadata
    row = {
        "question_id": "q1",
        "question": "Q",
        "corpus_id": "corpusA",
        "pos_doc": [{"id": "p1"}, {"id": "p2"}],
        "neg_doc": [{"id": "n1"}],
    }
    dataset = Dataset.from_list([row])
    dataset.corpus_dict = corpus_dict
    dataset.num_neg_docs = 1

    # The closure can't easily be inspected, so run the transform by
    # accessing an item after each epoch update.
    for current_epoch, expected_text in ((0, "pos1"), (1, "pos2")):
        rd.update_dataset_epoch(dataset, epoch=current_epoch)
        item = dataset[0]
        assert item["doc_text"][0] == expected_text
Loading