transformers v5 support #1164
Open
speediedan wants to merge 12 commits into TransformerLensOrg:dev-3.x from speediedan:basic-transformers-v5-support
Conversation
…xts.
We also update .gitignore to exclude .env (a commonly used local file exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.
Core Fixes:
-----------
transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
where tensors are upcast to float32 for numerical stability while cfg.dtype
remains float16/bfloat16
- Add explicit device/dtype synchronization for output projection:
* Move weights (W_O) and bias (b_O) to match input device (z.device)
* Ensure z matches weight dtype before final linear operation (see the sketch after this list)
transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
utility to ensure:
* All bridge components (not just original_model) are moved to target device
* cfg.device and cfg.dtype stay synchronized with actual model state
* Multi-GPU cache tensors remain on correct devices
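For illustration, a minimal sketch of the abstract_attention dtype/device handling referenced above (simplified shapes; `attention_out` is a hypothetical stand-in for the real per-head forward pass):

```python
import torch

def attention_out(pattern: torch.Tensor, v: torch.Tensor,
                  W_O: torch.Tensor, b_O: torch.Tensor) -> torch.Tensor:
    # pattern may have been upcast to float32 for numerical stability, so
    # cast it to v.dtype rather than cfg.dtype (which may still read float16).
    z = pattern.to(v.dtype) @ v
    # Keep parameters on the activations' device, and match z to the weight
    # dtype before the final output projection.
    W_O, b_O = W_O.to(z.device), b_O.to(z.device)
    z = z.to(W_O.dtype)
    return z @ W_O + b_O
```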
Test Fixes:
-----------
tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'
tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
"cpu" for proper device type validation
tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
correctly with bfloat16/float16 dtypes on CUDA devices
tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
random numbers appear in test names
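A hypothetical reconstruction of the parametrize-ID pattern (the actual test values differ):

```python
import pytest
import torch
from transformer_lens import FactoredMatrix

@pytest.mark.parametrize(
    "scalar",
    [torch.rand(1), float(torch.rand(1))],
    ids=["tensor_scalar", "float_scalar"],  # stable IDs keep random values out of test names
)
def test_multiply_by_scalar(scalar):
    fm = FactoredMatrix(torch.rand(2, 3), torch.rand(3, 4))
    assert torch.allclose((fm * scalar).AB, fm.AB * scalar)
```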
Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
Enhance the to() method to properly handle both device and dtype arguments in all supported PyTorch formats (positional, keyword, combined). Separately invoke move_to_and_update_config for device and dtype to update cfg, while delegating the actual tensor movement to original_model.to() with the original args/kwargs. This ensures TransformerBridge respects standard PyTorch behavior for model.to() calls.
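A sketch of the argument normalization, assuming an illustrative helper (`parse_to_args` is a hypothetical name, not the actual bridge internal):

```python
from typing import Optional, Tuple
import torch

def parse_to_args(*args, **kwargs) -> Tuple[Optional[torch.device], Optional[torch.dtype]]:
    """Extract device/dtype from the supported .to() call forms:
    .to(device), .to(dtype), .to(device, dtype), .to(device=..., dtype=...)."""
    device, dtype = kwargs.get("device"), kwargs.get("dtype")
    for arg in args:
        if isinstance(arg, torch.dtype):
            dtype = arg
        elif isinstance(arg, (str, int, torch.device)):
            device = torch.device(arg)
    return device, dtype

# Hypothetical usage inside TransformerBridge.to(): update cfg for each part,
# then delegate the actual move with the caller's original arguments.
# device, dtype = parse_to_args(*args, **kwargs)
# move_to_and_update_config(self, device)   # if device is not None
# move_to_and_update_config(self, dtype)    # if dtype is not None
# self.original_model.to(*args, **kwargs)
```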
Compatibility for transformers v5 and huggingface_hub v1.3.4 while maintaining backward compatibility with v4.

**Handle API/Behavioral Changes:**
- Handle batch_decode behavior change (wraps tokens for v4/v5 compatibility)
- Add rotary_pct → rope_parameters['partial_rotary_factor'] migration helper
- Fix BOS token handling for tokenizers without BOS (e.g., T5)
- Update MoE router_scores shape expectations for compact top-k format
- Add type casts for tokenizer.decode() return values

**Code Changes:**
- Add get_rotary_pct_from_config() utility for config v4/v5 compatibility
- Wrap tokens for batch_decode in HookedTransformer, bridge, and notebooks
- Add cast(str, ...) for decode() calls in generate() methods
- Update test expectations for new router_scores shape
- Add BOS token checks before setting add_bos_token=True

**Infrastructure:**
- Add pytest-rerunfailures dependency for flaky network tests (can be removed later once hub-related httpx read timeout issues are resolved)
- Update dependencies: transformers 5.0.0, huggingface_hub 1.3.4
- Change HF cache to use HF_HUB_CACHE (TRANSFORMERS_CACHE removed in v5)
- Update doctest to use range checks for numerical stability
…e httpx hub read timeouts affect both local and CI testing
Add transformers v5.0.0 and huggingface_hub v1.3.4 compatibility
Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundationally valuable contribution to the open-source interpretability ecosystem!
This PR adds compatibility with `transformers>=5.0.0` and `huggingface_hub>=1.3.4`, addressing behavioral changes in the tokenizer API, model configuration formats, and MoE layer outputs while retaining backward compatibility with `transformers` v4.x.

The Issues

1. `batch_decode` Behavior Change

In transformers v5, `tokenizer.batch_decode()` expects sequences (a list of lists), not individual token IDs. When passed a 1D tensor of token IDs, it decodes them together into a single string rather than returning individual token strings.

2. `rotary_pct` Configuration Rename

The `rotary_pct` attribute on HuggingFace configs (e.g., `GPTNeoXConfig`) has been moved to `rope_parameters['partial_rotary_factor']` in v5.

3. T5 Tokenizer BOS Token Error

Setting `add_bos_token=True` when loading a tokenizer now fails if the tokenizer has no BOS token (e.g., T5), causing `ValueError: add_bos_token = True but bos_token = None`.

4. MoE Router Scores Shape Change

In transformers v5, PR huggingface/transformers#42456 (https://github.com/huggingface/transformers/blame/6316a9e176d22f8f09d44ef72ec0aaa2ce7b8780/src/transformers/models/gpt_oss/modular_gpt_oss.py#L163-L164) changed the `GptOssTopKRouter.forward()` logic to remove the zeros_like and scatter operation, returning compact `[seq_len, top_k]`-shaped router scores instead of scattered `[seq_len, num_experts]`. In v4.57.1 the router returns shape `[seq_len, num_experts]`; in v5.0.0 it returns shape `[seq_len, num_experts_per_tok]`, as sketched below.
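An illustrative reconstruction of the two behaviors (hypothetical tensors, not the actual `GptOssTopKRouter` code):

```python
import torch

seq_len, num_experts, top_k = 5, 32, 4
router_logits = torch.randn(seq_len, num_experts)
top_vals, top_idx = router_logits.topk(top_k, dim=-1)
probs = torch.softmax(top_vals, dim=-1)

# v4.57.1: scores scattered back into a dense [seq_len, num_experts] tensor
scattered = torch.zeros_like(router_logits).scatter(-1, top_idx, probs)
assert scattered.shape == (seq_len, num_experts)   # (5, 32)

# v5.0.0: the zeros_like/scatter step is removed; scores stay compact
assert probs.shape == (seq_len, top_k)             # (5, 4)
```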
The Solutions
1. Wrap tokens for batch_decode (`transformer_lens/HookedTransformer.py`, `transformer_lens/model_bridge/bridge.py`)

In v5, `decode()` is unified to handle batched input (detecting batches by checking whether the first element is a list), while v4's `decode()` only handles single sequences. Using `batch_decode()` with wrapped tokens provides cross-version compatibility: v4 dispatches to `decode([token])` for each entry, and v5's `decode()` detects the batch format and processes it correctly.
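A sketch of the wrapping pattern, assuming a 1D `tokens` tensor (the GPT-2 example IDs below are illustrative):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokens = torch.tensor([15496, 995])  # "Hello" and " world" under the GPT-2 BPE

# Wrapping each ID as its own single-token sequence keeps per-token decoding
# working on both v4 (decode([token]) per entry) and v5 (batch detection).
per_token = tokenizer.batch_decode([[t] for t in tokens.tolist()])
assert per_token == ["Hello", " world"]
```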
2. Shared config compatibility utility (`transformer_lens/utilities/hf_utils.py`)

Added a `get_rotary_pct_from_config()` helper function that handles both v4 and v5 config formats.
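A minimal sketch of the helper's logic, assuming v4 configs expose `rotary_pct` directly and v5 configs carry a `rope_parameters` dict (the `1.0` fallback is an assumption):

```python
def get_rotary_pct_from_config(hf_config) -> float:
    # v4.x: GPTNeoXConfig and friends expose rotary_pct directly
    if hasattr(hf_config, "rotary_pct"):
        return hf_config.rotary_pct
    # v5.x: the value migrated into rope_parameters
    rope_parameters = getattr(hf_config, "rope_parameters", None) or {}
    return rope_parameters.get("partial_rotary_factor", 1.0)
```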
3. BOS token existence check (`transformer_lens/utilities/tokenize_utils.py`, `transformer_lens/model_bridge/sources/transformers.py`)

Only set `add_bos_token=True` when the tokenizer actually defines a BOS token.
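A sketch of the guard (variable names are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # T5 has no BOS token
kwargs = {}
if tokenizer.bos_token is not None:
    # Only request BOS prepending when a BOS token actually exists; in v5,
    # add_bos_token=True with bos_token=None raises a ValueError.
    kwargs["add_bos_token"] = True
assert "add_bos_token" not in kwargs
```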
4. Update MoE test expectations (`tests/unit/model_bridge/test_gpt_oss_moe.py`)

Updated the router scores shape expectation from `(5, 32)` to `(5, 4)` to match the new compact top-k format described above. Note that this test will only pass with transformers v5.x; we can consider version-conditional testing if v4 support is still needed here.
5. Type cast decode() return values (`transformer_lens/HookedTransformer.py`, `transformer_lens/HookedEncoderDecoder.py`)

In transformers v5, `tokenizer.decode()` has a union return type `str | list[str]` to support both single and batched inputs. When decoding a single sequence it returns `str`, but mypy cannot infer this statically, causing type errors in the `generate()` methods. Added `cast(str, ...)` around the `tokenizer.decode()` calls to satisfy mypy.
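A sketch of the cast, with an illustrative standalone setup:

```python
from typing import cast
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
output_tokens = [15496, 995]

# decode() on a single sequence returns str at runtime, but its v5 annotation
# is str | list[str]; cast() narrows the type for mypy with no runtime effect.
text = cast(str, tokenizer.decode(output_tokens))
```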
6. Update notebook batch_decode usage (`demos/ARENA_Content.ipynb`)

The ARENA_Content notebook passes a 1D tensor directly to `batch_decode()`, which works in v4 but fails in v5 due to the same behavior change described above. Wrapped the tokens for v5 compatibility, as in item 1.
7. Flaky test handling (`pyproject.toml`, `makefile`)

Added the `pytest-rerunfailures` dependency to handle intermittent httpx network timeouts when downloading from HuggingFace Hub. Added a `RERUN_ARGS` variable to the makefile, applied to all test targets (unit, integration, acceptance, benchmark, coverage, and notebook tests).

Files Changed
- `transformer_lens/utilities/hf_utils.py` (new `get_rotary_pct_from_config()` for HF config v4/v5 compatibility)
- `transformer_lens/utilities/__init__.py`
- `transformer_lens/loading_from_pretrained.py`
- `transformer_lens/model_bridge/generalized_components/attention.py`
- `transformer_lens/HookedTransformer.py`
- `transformer_lens/HookedEncoderDecoder.py`
- `transformer_lens/model_bridge/bridge.py`
- `transformer_lens/utilities/tokenize_utils.py`
- `transformer_lens/model_bridge/sources/transformers.py`
- `transformer_lens/evals.py`
- `tests/unit/model_bridge/test_gpt_oss_moe.py`
- `demos/ARENA_Content.ipynb`
- `pyproject.toml`
- `makefile`

Testing
- `pytest-rerunfailures` plugin confirmed working (1 flaky test auto-retried)

Key Migration References
Checklist