Conversation

@xiaoxi-wangfj
Contributor

Description

This PR fixes a NaN issue in the fused permute+pad path when handling Float8BlockwiseQTensor inputs.

The permuted_scale buffer was allocated with torch.empty, which does not initialize memory, so its padded regions could contain NaN values.

When the permute input is a Float8BlockwiseQTensor, these NaNs propagate through the subsequent dequantization and requantization path in GroupedLinear, eventually producing a NaN forward loss, e.g.:
ERROR:megatron.core.rerun_state_machine:Unexpected result nan on rank 1 at iteration #2 invocation #1 (message='found NaN in local forward loss calculation')
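
A minimal illustration of the propagation mechanism (not the actual GroupedLinear code; shapes and names are made up): a NaN in a per-row scale turns the whole dequantized row into NaN, regardless of the quantized data.

```python
import torch

# Toy row-wise dequantization; the real Float8BlockwiseQTensor layout differs.
qdata = torch.ones(4, 8)                              # stand-in for quantized values
scale = torch.tensor([1.0, 1.0, 1.0, float("nan")])   # NaN in a padded row's scale

dequant = qdata * scale.unsqueeze(1)                  # NaN * anything -> NaN
print(torch.isnan(dequant).any())                     # tensor(True)
```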

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Initialize permuted_scale with torch.zeros instead of torch.empty in permute_with_mask_map (see the sketch below)
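
A minimal sketch of the intended pattern, assuming an alloc-style helper like the one described in the review below; the names, signature, and shapes here are illustrative and may differ from the actual code in permutation.py:

```python
import torch

def alloc(shape, dtype, device, pad_offsets=None):
    # Hypothetical helper: zero-fill when the fused-pad path is active so that
    # rows the permute kernel never writes still read as 0.0 instead of garbage.
    if pad_offsets is not None:
        return torch.zeros(*shape, dtype=dtype, device=device)
    return torch.empty(*shape, dtype=dtype, device=device)

# Before: permuted_scale = torch.empty(num_out_tokens, scale_dim)  # padded rows undefined
# After (illustrative values):
num_out_tokens, scale_dim, pad_offsets = 16, 4, torch.tensor([0, 8])
permuted_scale = alloc((num_out_tokens, scale_dim), torch.float32, "cpu", pad_offsets)
```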

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@greptile-apps
Contributor

greptile-apps bot commented Dec 29, 2025

Greptile Summary

This PR fixes a critical NaN issue in the fused permute+pad path when handling Float8BlockwiseQTensor inputs by replacing torch.empty with alloc (which uses torch.zeros when pad_offsets is provided) for permuted_scale initialization.

Changes:

  • Modified permuted_scale initialization in permute_with_mask_map to use the existing alloc helper instead of torch.empty

Root Cause:
When pad_offsets is provided (indicating fused padding), the Triton _permute_kernel only writes to non-padded regions of permuted_scale. Since torch.empty doesn't initialize memory, padded regions could contain garbage NaN values. These NaNs propagate through subsequent dequantization/requantization in GroupedLinear, eventually causing NaN in forward loss calculations.
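
A small demonstration of the root cause described above, independent of Triton: if the allocation is uninitialized and the kernel only writes the non-padded rows, the remaining rows keep whatever bytes were already in memory (hypothetical sizes):

```python
import torch

rows, written, cols = 8, 5, 4        # illustrative sizes only

buf = torch.empty(rows, cols)        # uninitialized allocation
buf[:written] = 1.0                  # mimics a kernel that writes only non-padded rows
# buf[written:] holds arbitrary leftover bytes, which can decode to NaN

safe = torch.zeros(rows, cols)       # zero-filled allocation (the fix)
safe[:written] = 1.0                 # padded rows are guaranteed to read as 0.0
assert not torch.isnan(safe).any()
```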

Solution Correctness:
The fix correctly mirrors the existing pattern used for output and permuted_probs buffers on lines 162-166, which already use torch.zeros when pad_offsets is provided. The change is minimal, targeted, and aligns with the established convention in the codebase for handling padded regions.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The fix is a one-line change that follows the existing pattern in the same function for similar buffers (output and permuted_probs). It addresses a genuine bug where uninitialized memory could contain NaN values. The change is conservative, minimal, and directly targets the root cause without introducing new complexity or breaking existing functionality.
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/triton/permutation.py | Changed torch.empty to alloc (which uses torch.zeros when pad_offsets is provided) for permuted_scale initialization to prevent NaN values in padded regions |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Permute as permute_with_mask_map
    participant Kernel as _permute_kernel
    participant GroupedLinear as GroupedLinear

    Note over User,GroupedLinear: Before Fix: NaN Propagation Path
    User->>Permute: Float8BlockwiseQTensor input
    Permute->>Permute: torch.empty(permuted_scale)<br/>[Contains NaN in padded regions]
    Permute->>Kernel: Execute permutation
    Kernel->>Kernel: Store scale to permuted_scale<br/>[NaN remains in unwritten padding]
    Kernel-->>Permute: permuted_scale with NaN
    Permute-->>GroupedLinear: Float8BlockwiseQTensor with NaN scale
    GroupedLinear->>GroupedLinear: Dequantize + Requantize
    Note over GroupedLinear: NaN propagates through operations
    GroupedLinear-->>User: NaN in forward loss

    Note over User,GroupedLinear: After Fix: Zero Initialization
    User->>Permute: Float8BlockwiseQTensor input<br/>with pad_offsets
    Permute->>Permute: torch.zeros(permuted_scale)<br/>[All zeros, including padding]
    Permute->>Kernel: Execute permutation
    Kernel->>Kernel: Store scale to permuted_scale<br/>[Zeros remain in unwritten padding]
    Kernel-->>Permute: permuted_scale with zeros
    Permute-->>GroupedLinear: Float8BlockwiseQTensor with valid scale
    GroupedLinear->>GroupedLinear: Dequantize + Requantize
    GroupedLinear-->>User: Valid forward loss
```

@xiaoxi-wangfj
Contributor Author

@tdophung
Apologies for the oversight. In the previous change, when switching to torch.empty to initialize permuted_scale, I didn’t re-validate the code path where the permute input is a Float8BlockwiseQTensor.
Today, I reran the FP8-Flow setup (with Float8BlockwiseQTensor inputs) and was able to reproduce the issue. This PR fixes the problem; thanks for the review.
