Handle pure bodies in nnx.fori_loop #5141
Open
+66
−24
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fixes #5113
Currently,
nnx.fori_looponly works with loop bodies that perform some kind of mutation. This is because the the call to_add_fake_index_mapping(pure_init_val)in line 1625 offlax/nnx/transforms/iteration.pyadds anouter_indexattribute that won't be added in the corresponding pure body. To fix this, this PR checks what the output of the loop function looks like first. If it has an outer index, we add a fake index mapping. Otherwise, we leave it be.