@speediedan commented on Jan 29, 2026

Add transformers v5.0.0 and huggingface_hub v1.3.4 compatibility

Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundationally valuable contribution to the open-source interpretability ecosystem!

This PR adds compatibility with transformers>=5.0.0 and huggingface_hub>=1.3.4, addressing behavioral changes in the tokenizer API, model configuration formats, and MoE layer outputs while retaining backward compatibility with transformers v4.x.

The Issues

1. batch_decode Behavior Change

In transformers v5, tokenizer.batch_decode() expects sequences (list of lists), not individual token IDs. When passing a 1D tensor of token IDs, it decodes them together into a single string rather than returning individual token strings.
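
A minimal illustration of the difference (the token IDs and decoded strings here are hypothetical and depend on the tokenizer vocabulary):

# Hypothetical GPT-2-style token IDs for "Hello" and " world"
token_ids = [15496, 995]

# transformers v4: each ID is decoded on its own
# tokenizer.batch_decode(token_ids)  -> ["Hello", " world"]

# transformers v5: the flat list is treated as one sequence
# tokenizer.batch_decode(token_ids)  -> a single joined string, "Hello world"

# Wrapping each ID in a list restores per-token decoding in both versions
# tokenizer.batch_decode([[t] for t in token_ids])  -> ["Hello", " world"]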

2. rotary_pct Configuration Rename

The rotary_pct attribute on HuggingFace configs (e.g., GPTNeoXConfig) has been moved to rope_parameters['partial_rotary_factor'] in v5.

3. T5 Tokenizer BOS Token Error

Setting add_bos_token=True when loading a tokenizer now fails if the tokenizer has no BOS token (e.g., T5), causing ValueError: add_bos_token = True but bos_token = None.

4. MoE Router Scores Shape Change

In transformers v5, huggingface/transformers#42456 (https://github.com/huggingface/transformers/blame/6316a9e176d22f8f09d44ef72ec0aaa2ce7b8780/src/transformers/models/gpt_oss/modular_gpt_oss.py#L163-L164) changed the GptOssTopKRouter.forward() logic to remove the zeros_like/scatter step, so the router now returns compact [seq_len, top_k] scores instead of scattered [seq_len, num_experts] scores.

v4.57.1 behavior (router returns shape [seq_len, num_experts]):

# v4.57.1 GptOssTopKRouter.forward()
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices  # router_scores: [seq_len, num_experts]

v5.0.0 behavior (router returns shape [seq_len, num_experts_per_tok]):

# v5.0.0 GptOssTopKRouter.forward()
router_scores = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
return router_logits, router_scores, router_indices  # router_scores: [seq_len, top_k]
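
For code that still expects the scattered [seq_len, num_experts] layout, the compact v5 output can be expanded back with the same scatter pattern the v4 router used; a minimal sketch (not part of this PR, reusing the variable names from the snippets above):

# Re-scatter compact [seq_len, top_k] scores into [seq_len, num_experts]
scattered_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_scores)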

The Solutions

1. Wrap tokens for batch_decode (transformer_lens/HookedTransformer.py, transformer_lens/model_bridge/bridge.py)

In v5, decode() is unified to handle batched input (detecting batches by checking if the first element is a list). However, v4's decode() only handles single sequences. Using batch_decode() with wrapped tokens provides cross-version compatibility:

  • v4: iterates over outer list, calling decode([token]) for each
  • v5: forwards to unified decode(), detects batch format, processes correctly
# Wrap each token ID in a list for v4/v5 compatibility
tokens_list = [[int(t)] for t in tokens.tolist()]
str_tokens = self.tokenizer.batch_decode(tokens_list, clean_up_tokenization_spaces=False)
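
With the wrapper in place, per-token decoding behaves the same under both versions; an illustrative usage (output assumes a GPT-2-style tokenizer and a loaded HookedTransformer named model):

str_tokens = model.to_str_tokens("Hello world", prepend_bos=False)
# -> ["Hello", " world"] under both transformers v4 and v5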

2. Shared config compatibility utility (transformer_lens/utilities/hf_utils.py)

Added get_rotary_pct_from_config() helper function that handles both v4 and v5 config formats:

def get_rotary_pct_from_config(config: Any) -> float:
    """Get the rotary percentage from a config object.
    Handles both transformers v4 (rotary_pct) and v5 (rope_parameters['partial_rotary_factor']).
    """
    if hasattr(config, "rotary_pct"):
        return getattr(config, "rotary_pct", 1.0)
    if hasattr(config, "rope_parameters"):
        rope_params = getattr(config, "rope_parameters", None)
        if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
            return rope_params["partial_rotary_factor"]
    return 1.0
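
An illustrative (hypothetical) call site; the actual usage in loading_from_pretrained.py may differ slightly:

# Works for both v4 (rotary_pct) and v5 (rope_parameters) config formats
d_head = hf_config.hidden_size // hf_config.num_attention_heads
rotary_dim = round(get_rotary_pct_from_config(hf_config) * d_head)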

3. BOS token existence check (transformer_lens/utilities/tokenize_utils.py, transformer_lens/model_bridge/sources/transformers.py)

# Only set add_bos_token=True if tokenizer actually has a BOS token
if tokenizer.bos_token is None:
    return tokenizer
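
In context, the guard looks roughly like the following (a sketch with a hypothetical function name; the actual code in tokenize_utils.py may differ):

def maybe_add_bos_token(tokenizer):
    """Enable add_bos_token only when the tokenizer actually defines a BOS token."""
    if tokenizer.bos_token is None:
        # e.g. T5: no BOS token, so leave add_bos_token unset to avoid the v5 ValueError
        return tokenizer
    tokenizer.add_bos_token = True
    return tokenizer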

4. Update MoE test expectations (tests/unit/model_bridge/test_gpt_oss_moe.py)

Updated router scores shape expectation from (5, 32) to (5, 4) to match the new compact top-k format alluded to above. Note this test will only pass with transformers v5.x. We can consider version-conditional testing if v4 support is still needed here.

5. Type cast decode() return values (transformer_lens/HookedTransformer.py, transformer_lens/HookedEncoderDecoder.py)

In transformers v5, tokenizer.decode() has a union return type (str | list[str]) to support both single and batched inputs. Decoding a single sequence always returns str, but mypy cannot infer this statically, causing type errors in the generate() methods. Added cast(str, ...) around those decode() calls:

from typing import cast
# ...
return cast(str, self.tokenizer.decode(decoder_input[0], skip_special_tokens=True))

6. Update notebook batch_decode usage (demos/ARENA_Content.ipynb)

The ARENA_Content notebook passes a 1D tensor directly to batch_decode(), which works in v4 but fails in v5 due to the behavior change described in Issue 1. Wrapped the tokens the same way as above:

# Wrap each token in a list for v4/v5 batch_decode compatibility
reference_gpt2.tokenizer.batch_decode([[int(t)] for t in logits.argmax(dim=-1)[0].tolist()])

7. Flaky test handling (pyproject.toml, makefile)

Added pytest-rerunfailures dependency to handle intermittent httpx network timeouts when downloading from HuggingFace Hub. Added RERUN_ARGS variable to makefile, applied to all test targets (unit, integration, acceptance, benchmark, coverage, and notebook tests):

# Rerun args for flaky tests (httpx timeouts during HF Hub downloads)
# Remove this line when no longer needed
RERUN_ARGS := --reruns 2 --reruns-delay 5
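
pytest-rerunfailures also supports opting in per test via a marker, should the blanket makefile flag later be narrowed; illustrative only:

import pytest

@pytest.mark.flaky(reruns=2, reruns_delay=5)
def test_model_download_from_hub():
    ...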

Files Changed

| File | Change |
| --- | --- |
| transformer_lens/utilities/hf_utils.py | Added get_rotary_pct_from_config() for HF config v4/v5 compatibility |
| transformer_lens/utilities/__init__.py | Export new utility |
| transformer_lens/loading_from_pretrained.py | Import shared utility |
| transformer_lens/model_bridge/generalized_components/attention.py | Import shared utility |
| transformer_lens/HookedTransformer.py | batch_decode fix, decode() type cast in generate() |
| transformer_lens/HookedEncoderDecoder.py | decode() type cast in generate() |
| transformer_lens/model_bridge/bridge.py | batch_decode fix |
| transformer_lens/utilities/tokenize_utils.py | BOS token check |
| transformer_lens/model_bridge/sources/transformers.py | BOS token check for tokenizer loading |
| transformer_lens/evals.py | Doctest range checks for numerical stability |
| tests/unit/model_bridge/test_gpt_oss_moe.py | Router scores shape fix |
| demos/ARENA_Content.ipynb | batch_decode fix for v5 compatibility |
| pyproject.toml | pytest-rerunfailures dependency |
| makefile | RERUN_ARGS for httpx timeout retries |

Testing

  • All existing unit tests pass locally with transformers v5.0.0 and huggingface_hub v1.3.4. All tests except the MoE router shape test pass with transformers v4.57.1 and huggingface_hub 0.34.1.
  • pytest-rerunfailures plugin confirmed working (1 flaky test auto-retried)
  • Intermittent httpx network timeouts during HF Hub model downloads are infrastructure-related, not code issues

Key Migration References


Checklist

  • New feature (non-breaking change which adds functionality)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

…xts.

We also update .gitignore to exclude .env (a commonly used local-file exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.

Core Fixes:
-----------

transformer_lens/components/abstract_attention.py:
  - Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
    where tensors are upcast to float32 for numerical stability while cfg.dtype
    remains float16/bfloat16
  - Add explicit device/dtype synchronization for output projection:
    * Move weights (W_O) and bias (b_O) to match input device (z.device)
    * Ensure z matches weight dtype before final linear operation
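
A minimal sketch of the intent (hypothetical surrounding code; not the literal diff):

# The attention pattern may have been upcast to float32 for numerical stability,
# so follow the value tensor's dtype rather than cfg.dtype
pattern = pattern.to(v.dtype)

# Keep the output projection's weights, bias, and input aligned before the final linear op
w_o = self.W_O.to(z.device)
b_o = self.b_O.to(z.device)
z = z.to(w_o.dtype)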

transformer_lens/model_bridge/bridge.py:
  - Replace direct original_model.to() call with move_to_and_update_config()
    utility to ensure:
    * All bridge components (not just original_model) are moved to target device
    * cfg.device and cfg.dtype stay synchronized with actual model state
    * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------

tests/acceptance/test_hooked_encoder.py:
  - Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
  - Update test_cache_device() to pass torch.device("cpu") instead of string
    "cpu" for proper device type validation

tests/unit/components/test_attention.py:
  - Add test_attention_forward_half_precisions() to validate attention works
    correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
  - Add test IDs to parametrize decorators to avoid pytest cache issues when
    random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]

Enhance to() method to properly handle both device and dtype arguments in
all supported PyTorch formats (positional, keyword, combined). Separately
invoke move_to_and_update_config for device/dtype to update cfg while
delegating the actual tensor movement to original_model.to() with original
args/kwargs. This ensures TransformerBridge respects standard PyTorch
behavior for model.to() calls.
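
A sketch of how the argument handling could look, assuming torch's internal _parse_to helper (the one nn.Module.to uses) and TransformerLens's move_to_and_update_config utility (import path assumed); the actual implementation may differ:

import torch
from transformer_lens.utilities.devices import move_to_and_update_config

def to(self, *args, **kwargs):
    # Accept every supported form: .to("cuda"), .to(torch.float16), .to(device=..., dtype=...), etc.
    device, dtype, _non_blocking, _convert = torch._C._nn._parse_to(*args, **kwargs)
    if device is not None:
        move_to_and_update_config(self, device)  # keep cfg.device in sync
    if dtype is not None:
        move_to_and_update_config(self, dtype)   # keep cfg.dtype in sync
    # Delegate the actual tensor movement to the wrapped HF model
    self.original_model.to(*args, **kwargs)
    return self
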
Compatibility for transformers v5 and huggingface_hub v1.3.4
while maintaining backward compatibility with v4.

**Handle API/Behavioral Changes:**
- Handle batch_decode behavior change (wraps tokens for v4/v5 compatibility)
- Add rotary_pct → rope_parameters['partial_rotary_factor'] migration helper
- Fix BOS token handling for tokenizers without BOS (e.g., T5)
- Update MoE router_scores shape expectations for compact top-k format
- Add type casts for tokenizer.decode() return values

**Code Changes:**
- Add get_rotary_pct_from_config() utility for config v4/v5 compatibility
- Wrap tokens for batch_decode in HookedTransformer, bridge, and notebooks
- Add cast(str, ...) for decode() calls in generate() methods
- Update test expectations for new router_scores shape
- Add BOS token checks before setting add_bos_token=True

**Infrastructure:**
- Add pytest-rerunfailures dependency for flaky network tests (can be removed later once hub-related httpx read timeout issues are resolved)
- Update dependencies: transformers 5.0.0, huggingface_hub 1.3.4
- Change HF cache to use HF_HUB_CACHE (TRANSFORMERS_CACHE removed in v5)
- Update doctest to use range checks for numerical stability
…e httpx hub read timeouts affect both local and CI testing
@speediedan changed the title from "Basic transformers v5 support" to "transformers v5 support" on Jan 29, 2026
@speediedan marked this pull request as ready for review on January 30, 2026 at 00:14