From c4f83a92e76f53557a80e6f741b111271b3dfada Mon Sep 17 00:00:00 2001 From: Yuhe Zhang Date: Thu, 4 Dec 2025 18:18:03 -0500 Subject: [PATCH] Updates biencoder training to utilize all available positive documents for a query by cycling through them based on the current epoch, instead of always using the first one Signed-off-by: Yuhe Zhang --- .../datasets/llm/retrieval_dataset.py | 31 +++++++-- .../recipes/biencoder/train_biencoder.py | 6 ++ .../datasets/llm/test_retrieval_dataset.py | 65 +++++++++++++++++++ 3 files changed, 98 insertions(+), 4 deletions(-) diff --git a/nemo_automodel/components/datasets/llm/retrieval_dataset.py b/nemo_automodel/components/datasets/llm/retrieval_dataset.py index 4e128a007..58a1e1752 100644 --- a/nemo_automodel/components/datasets/llm/retrieval_dataset.py +++ b/nemo_automodel/components/datasets/llm/retrieval_dataset.py @@ -165,7 +165,7 @@ def _transform_func(examples, num_neg_docs, corpus_dict): Same as _format_process_data in RetrievalMultiModalDatasetLoader. Args: - examples: Batch of examples with question, corpus_id, pos_doc, neg_doc + examples: Batch of examples with question, corpus_id, pos_doc, neg_doc, and optional epoch for cycling through positive documents num_neg_docs: Number of negative documents to use corpus_dict: Dictionary mapping corpus_id to corpus objects """ @@ -180,6 +180,7 @@ def _transform_func(examples, num_neg_docs, corpus_dict): corpus_ids = examples["corpus_id"] batch_positives = examples["pos_doc"] batch_negatives = examples["neg_doc"] + epoch = examples.get("epoch", 0) # Get epoch from examples if present, else default to 0 # Check if we have enough negatives if num_neg_docs > len(batch_negatives[0]): @@ -190,10 +191,10 @@ def _transform_func(examples, num_neg_docs, corpus_dict): for i_example in range(len(questions)): cur_pos_neg_doc = [] - # Get one positive doc (take first one) + # Get one positive doc (cycle through positives based on epoch) positives = batch_positives[i_example] if isinstance(positives, list) and len(positives) > 0: - cur_pos_neg_doc.append(positives[0]) + cur_pos_neg_doc.append(positives[epoch % len(positives)]) else: cur_pos_neg_doc.append(positives) @@ -246,10 +247,12 @@ def _transform_func(examples, num_neg_docs, corpus_dict): return result -def _create_transform_func(num_neg_docs, corpus_dict): +def _create_transform_func(num_neg_docs, corpus_dict, epoch=0): """Create transform function with specified number of negative documents.""" def transform(examples): + # Inject epoch into examples so _transform_func can use it + examples["epoch"] = epoch return _transform_func(examples, num_neg_docs=num_neg_docs, corpus_dict=corpus_dict) return transform @@ -313,6 +316,10 @@ def make_retrieval_dataset( negative_size = train_n_passages - 1 dataset.set_transform(_create_transform_func(negative_size, corpus_dict)) + # Store metadata for updating transform later + dataset.corpus_dict = corpus_dict + dataset.num_neg_docs = negative_size + elif data_type == "eval": # Set transform for evaluation dataset.set_transform(_create_transform_func(eval_negative_size, corpus_dict)) @@ -325,6 +332,22 @@ def make_retrieval_dataset( return dataset +def update_dataset_epoch(dataset, epoch): + """ + Update the dataset transform to use the specified epoch for positive document selection. + + Args: + dataset: The HuggingFace dataset + epoch: The new epoch number + """ + if hasattr(dataset, "corpus_dict") and hasattr(dataset, "num_neg_docs"): + dataset.set_transform(_create_transform_func(dataset.num_neg_docs, dataset.corpus_dict, epoch=epoch)) + else: + # If metadata is missing (e.g. eval dataset or loaded differently), we can't update. + # This is expected for eval datasets or if make_retrieval_dataset wasn't used. + pass + + if __name__ == "__main__": import argparse diff --git a/nemo_automodel/recipes/biencoder/train_biencoder.py b/nemo_automodel/recipes/biencoder/train_biencoder.py index e2f18d431..c795e6aee 100644 --- a/nemo_automodel/recipes/biencoder/train_biencoder.py +++ b/nemo_automodel/recipes/biencoder/train_biencoder.py @@ -40,6 +40,7 @@ from nemo_automodel.components.training.step_scheduler import StepScheduler from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm from nemo_automodel.recipes.base_recipe import BaseRecipe +from nemo_automodel.components.datasets.llm.retrieval_dataset import update_dataset_epoch if TYPE_CHECKING: from nemo_automodel.components.distributed.init_utils import DistInfo @@ -511,6 +512,11 @@ def run_train_validation_loop(self): for epoch in self.step_scheduler.epochs: self.step_scheduler.set_epoch(epoch) + + # Update dataset epoch for positive document cycling + if hasattr(self.dataloader, "dataset"): + update_dataset_epoch(self.dataloader.dataset, epoch) + # The step scheduler yields a list of batches for gradient accumulation for batches in self.step_scheduler: train_log_data = self._run_train_optim_step(batches, 1.0) diff --git a/tests/unit_tests/datasets/llm/test_retrieval_dataset.py b/tests/unit_tests/datasets/llm/test_retrieval_dataset.py index 8307dd868..65c5061cd 100644 --- a/tests/unit_tests/datasets/llm/test_retrieval_dataset.py +++ b/tests/unit_tests/datasets/llm/test_retrieval_dataset.py @@ -385,3 +385,68 @@ def test_make_retrieval_dataset_invalid_type(tmp_path, monkeypatch): rd.make_retrieval_dataset(str(train_file), data_type="invalid") +def test_transform_func_epoch_cycling(): + corpus_dict = { + "corpusA": DummyCorpus( + { + "p1": {"text": "pos1", "image": "", "nr_ocr": ""}, + "p2": {"text": "pos2", "image": "", "nr_ocr": ""}, + "p3": {"text": "pos3", "image": "", "nr_ocr": ""}, + "n1": {"text": "neg1", "image": "", "nr_ocr": ""}, + } + ) + } + + # Example with multiple positive docs + examples = { + "question": ["Q"], + "corpus_id": ["corpusA"], + "pos_doc": [[{"id": "p1"}, {"id": "p2"}, {"id": "p3"}]], + "neg_doc": [[{"id": "n1"}]], + } + + # Epoch 0: Should select first positive (p1) + examples["epoch"] = 0 + out_0 = rd._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict) + assert out_0["doc_text"][0][0] == "pos1" + + # Epoch 1: Should select second positive (p2) + examples["epoch"] = 1 + out_1 = rd._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict) + assert out_1["doc_text"][0][0] == "pos2" + + # Epoch 2: Should select third positive (p3) + examples["epoch"] = 2 + out_2 = rd._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict) + assert out_2["doc_text"][0][0] == "pos3" + + # Epoch 3: Should cycle back to first positive (p1) + examples["epoch"] = 3 + out_3 = rd._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict) + assert out_3["doc_text"][0][0] == "pos1" + + # Test update_dataset_epoch helper + # Create a dummy dataset with metadata + dataset = Dataset.from_list([ + { + "question_id": "q1", + "question": "Q", + "corpus_id": "corpusA", + "pos_doc": [{"id": "p1"}, {"id": "p2"}], + "neg_doc": [{"id": "n1"}] + } + ]) + dataset.corpus_dict = corpus_dict + dataset.num_neg_docs = 1 + + # Initial transform (epoch 0 default) + rd.update_dataset_epoch(dataset, epoch=0) + # Verify transform is set (we can't easily inspect the closure, but we can run it) + # The dataset transform is applied when accessing items + item_0 = dataset[0] + assert item_0["doc_text"][0] == "pos1" + + # Update to epoch 1 + rd.update_dataset_epoch(dataset, epoch=1) + item_1 = dataset[0] + assert item_1["doc_text"][0] == "pos2"