feat: Enable cycling through all positive documents in biencoder training #907 #933
base: main
Conversation
…s for a query by cycling through them based on the current epoch, instead of always using the first one

Signed-off-by: Yuhe Zhang <yuhe@polarr.co>
/ok to test c4f83a9

@akoumpa, there was an error processing your request. See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/
/ok to test f57eaa4
Hi @akoumpa, thanks for triggering the CICD test on this PR. I just noticed there is a newer draft PR (#937) working on the same multi-positive cycling behavior for biencoder training, but using a different implementation. I wanted to mention this in case it helps decide which approach fits Automodel better. I'm happy to update this PR or align with whichever direction the team prefers.
Apologies for the delay @yuhezhang-ai, I will let @shan-nvidia guide the decision further here.
@yuhezhang-ai Thanks for this PR. I have a suggestion: convert the transform into a stateful class that keeps track of the epoch state, instead of injecting the epoch into the examples. Something like:

```python
class RetrievalTransform:
    """Transform for retrieval datasets with epoch-based positive cycling."""

    def __init__(self, num_neg_docs: int, corpus_dict: dict):
        self.num_neg_docs = num_neg_docs
        self.corpus_dict = corpus_dict
        self.epoch = 0

    def __call__(self, examples):
        return _transform_func(
            examples,
            num_neg_docs=self.num_neg_docs,
            corpus_dict=self.corpus_dict,
            epoch=self.epoch,  # Pass as parameter, don't inject into examples
        )

    def set_epoch(self, epoch: int):
        self.epoch = epoch
```

Then in setup:

```python
transform = RetrievalTransform(negative_size, corpus_dict)
dataset.set_transform(transform)  # Called once
dataset.set_epoch = transform.set_epoch  # Expose on dataset
```

The training loop becomes:

```python
if hasattr(self.dataloader.dataset, "set_epoch"):
    self.dataloader.dataset.set_epoch(epoch)
```
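For illustration, here is a self-contained sketch of the stateful-transform pattern suggested above. The dataset plumbing and `_transform_func` are omitted, so the field names (`query`, `positive_docs`) and inline cycling logic are assumptions, not the PR's actual code:

```python
class RetrievalTransform:
    """Stateful transform: picks a positive based on the current epoch."""

    def __init__(self):
        self.epoch = 0

    def __call__(self, example: dict) -> dict:
        positives = example["positive_docs"]
        # Modulo cycling: epoch 0 -> doc 0, epoch 1 -> doc 1, wrapping around.
        chosen = positives[self.epoch % len(positives)]
        return {"query": example["query"], "positive": chosen}

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


transform = RetrievalTransform()
example = {"query": "q1", "positive_docs": ["d0", "d1"]}

transform.set_epoch(0)
print(transform(example)["positive"])  # d0
transform.set_epoch(1)
print(transform(example)["positive"])  # d1
transform.set_epoch(2)
print(transform(example)["positive"])  # d0 again (wraps around)
```

Because the transform object is registered once and mutated via `set_epoch`, no re-registration of the transform is needed between epochs.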
What does this PR do?
This PR updates the biencoder training recipe to utilize all available positive documents for a given query. Previously, only the first positive document was used. Now, the training loop cycles through the list of positive documents across epochs using a modulo operation (e.g., Epoch 0 uses doc 0, Epoch 1 uses doc 1, etc.).
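The modulo selection described above can be sketched as follows (a minimal illustration; the function and argument names are assumptions, not the PR's actual code):

```python
def select_positive(positive_docs: list, epoch: int) -> str:
    """Cycle through positives across epochs: epoch 0 -> doc 0, epoch 1 -> doc 1, ..."""
    return positive_docs[epoch % len(positive_docs)]


docs = ["pos_a", "pos_b", "pos_c"]
print(select_positive(docs, 0))  # pos_a
print(select_positive(docs, 1))  # pos_b
print(select_positive(docs, 3))  # pos_a (wraps around)
```

Note that queries with a single positive behave exactly as before, since `epoch % 1 == 0` always selects the first document.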
Changelog
- `retrieval_dataset.py`: accept an epoch argument in the transform function; `update_dataset_epoch` helper to update the dataset transform.
- `train_biencoder.py`: call `update_dataset_epoch` at the start of each epoch.

Before your PR is "Ready for review"
Pre checks:
Additional Information
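For reference, one way an `update_dataset_epoch`-style helper could work is to re-bind the transform with the new epoch value. This is a hypothetical sketch; the PR's actual implementation is not shown in this excerpt, and `_transform_func` and `FakeDataset` here are stand-ins:

```python
from functools import partial


def _transform_func(examples, *, epoch):
    # Stand-in for the real retrieval transform, which would use `epoch`
    # to pick which positive document to return.
    return {**examples, "epoch": epoch}


def update_dataset_epoch(dataset, epoch: int):
    """Re-register the dataset transform with the new epoch value."""
    dataset.set_transform(partial(_transform_func, epoch=epoch))


class FakeDataset:
    """Minimal stand-in exposing the `set_transform` hook used above."""

    def set_transform(self, fn):
        self.transform = fn


ds = FakeDataset()
update_dataset_epoch(ds, 2)
print(ds.transform({"query": "q"}))  # {'query': 'q', 'epoch': 2}
```

This re-binding approach is what the stateful-class suggestion in the conversation avoids: with a class holding the epoch, the transform is registered once and only its internal state changes.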