Setting compile.enabled=true for Qwen3-32B will trigger error: Dynamo failed to run FX node with fake tensors #1083

@linmuchuiyang

Describe the bug

Setting compile.enabled=true for Qwen3-32B (with the FSDP2Manager config below, tp_size: 4) fails in the forward pass with:

Dynamo failed to run FX node with fake tensors

Steps/Code to reproduce bug

Docker image: nvcr.io/nvidia/nemo-automodel:25.11

YAML file:

step_scheduler:
  global_batch_size: 32
  local_batch_size: 4 #2
  ckpt_every_steps: 100
  val_every_steps: 100  # will run every x number of gradient steps
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3-32B

checkpoint:
  enabled: true
  checkpoint_dir: ./checkpoints/Qwen3_32B/
  model_save_format: safetensors
  save_consolidated: true 

# torch.compile configuration
compile:
  enabled: true
  mode: "default" # Options: "default", "reduce-overhead", "max-autotune"
  fullgraph: false
  dynamic: false  # Set to false for better performance with fixed shapes
  backend: null  # Use default backend (inductor)

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  #_target_: nemo_automodel.components.distributed.megatron_fsdp.MegatronFSDPManager
  dp_size: none
  dp_replicate_size: 1 # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 4
  pp_size: 1 # No PP size in megatron_fsdp.py #https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/components/distributed/megatron_fsdp.py
  cp_size: 1
  sequence_parallel: false

loss_fn:
  #_target_: nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy # TP will have errors
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

packed_sequence:
  packed_sequence_size: 1024

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: validation
  limit_dataset_samples: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  #batch_size: 1

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  fused: True

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

# Uncomment and configure for W&B logging
# wandb:
#   project: <your_wandb_project>
#   entity: <your_wandb_entity>
#   name: <your_wandb_exp_name>
#   save_dir: <your_wandb_save_dir>
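
For reference, the parallel layout this config implies can be worked out with a small sketch. It assumes (not confirmed in this report) that `dp_size: none` means the data-parallel size is inferred as world_size / (tp_size * pp_size * cp_size), and that the run used 4 GPUs, which is what the `tp=4` device mesh and the four worker processes in the log below suggest. The helper name `infer_layout` is hypothetical and not part of nemo_automodel.

```python
def infer_layout(world_size, tp_size, pp_size=1, cp_size=1, dp_replicate_size=1):
    """Derive dp_size and dp_shard_size from the world size (assumed inference rule)."""
    model_parallel = tp_size * pp_size * cp_size
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by tp_size * pp_size * cp_size")
    # Assumption: ranks not consumed by TP/PP/CP form the data-parallel dimension.
    dp_size = world_size // model_parallel
    if dp_size % dp_replicate_size != 0:
        raise ValueError("dp_size must be divisible by dp_replicate_size")
    # From the config comment: dp_shard_size = dp_size / dp_replicate_size.
    dp_shard_size = dp_size // dp_replicate_size
    return {"dp_size": dp_size, "dp_shard_size": dp_shard_size}


# Values from the YAML above; world_size=4 is an assumption based on the log.
print(infer_layout(world_size=4, tp_size=4, pp_size=1, cp_size=1, dp_replicate_size=1))
```

Under these assumptions all four ranks form a single tensor-parallel group, which is consistent with the failure surfacing in the tensor-parallel output redistribute of `embed_tokens` rather than in data-parallel code.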

The run fails with the following output (traceback from rank 3):

[rank3]: Traceback (most recent call last):
[rank3]:   File "/opt/Automodel/examples/llm_finetune/finetune.py", line 33, in <module>
[rank3]:     main()
[rank3]:   File "/opt/Automodel/examples/llm_finetune/finetune.py", line 29, in main
[rank3]:     recipe.run_train_validation_loop()
[rank3]:   File "/opt/Automodel/nemo_automodel/recipes/llm/train_ft.py", line 1127, in run_train_validation_loop
[rank3]:     train_log_data = self._run_train_optim_step(batches, 1.0)
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/Automodel/nemo_automodel/recipes/llm/train_ft.py", line 1268, in _run_train_optim_step
[rank3]:     self._forward_backward_step(
[rank3]:   File "/opt/Automodel/nemo_automodel/recipes/llm/train_ft.py", line 1226, in _forward_backward_step
[rank3]:     out = model(**batch)
[rank3]:           ^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 413, in __call__
[rank3]:     return super().__call__(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py", line 68, in inner
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1881, in _call_impl
[rank3]:     return inner()
[rank3]:            ^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1829, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/opt/Automodel/nemo_automodel/_transformers/auto_model.py", line 86, in wrapper
[rank3]:     return func(self, *args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1595, in __call__
[rank3]:     result = self._torchdynamo_orig_backend(
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1353, in __call__
[rank3]:     result = self._inner_convert(
[rank3]:              ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 682, in __call__
[rank3]:     result = _compile(
[rank3]:              ^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1172, in _compile
[rank3]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 92, in wrapper_function
[rank3]:     return function(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 858, in compile_inner
[rank3]:     return _compile_inner(code, one_graph, hooks, transform)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
[rank3]:     out_code = transform_code_object(code, transform)
[rank3]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
[rank3]:     transformations(instructions, code_options)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 300, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 818, in transform
[rank3]:     tracer.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3529, in run
[rank3]:     super().run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2241, in CALL_FUNCTION_EX
[rank3]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2241, in CALL_FUNCTION_EX
[rank3]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/nn_module.py", line 1006, in call_function
[rank3]:     return variables.UserFunctionVariable(fn, source=source).call_function(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2241, in CALL_FUNCTION_EX
[rank3]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
[rank3]:     self._call(inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
[rank3]:     self.call_function(fn, args, kwargs)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/nn_module.py", line 1006, in call_function
[rank3]:     return variables.UserFunctionVariable(fn, source=source).call_function(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
[rank3]:     self._call(inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
[rank3]:     self.call_function(fn, args, kwargs)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
[rank3]:     self._call(inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
[rank3]:     self.call_function(fn, args, kwargs)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
[rank3]:     self._call(inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
[rank3]:     self.call_function(fn, args, kwargs)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
[rank3]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 1889, in call_function
[rank3]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 578, in call_function
[rank3]:     return super().call_function(tx, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py", line 334, in call_function
[rank3]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1218, in inline_user_function_return
[rank3]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3734, in inline_call
[rank3]:     return tracer.inline_call_()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3937, in inline_call_
[rank3]:     self.run()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1373, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1277, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 853, in wrapper
[rank3]:     return inner_fn(self, inst)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2933, in CALL
[rank3]:     self._call(inst)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2927, in _call
[rank3]:     self.call_function(fn, args, kwargs)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1201, in call_function
[rank3]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/misc.py", line 1115, in call_function
[rank3]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 698, in call_method
[rank3]:     result = handler_method(*args, **kwargs)
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 1203, in method_redistribute
[rank3]:     return wrap_fx_proxy(
[rank3]:            ^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2667, in wrap_fx_proxy
[rank3]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2733, in wrap_fx_proxy_cls
[rank3]:     return _wrap_fx_proxy(
[rank3]:            ^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2831, in _wrap_fx_proxy
[rank3]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank3]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3423, in get_fake_value
[rank3]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3321, in get_fake_value
[rank3]:     ret_val = wrap_fake_exception(
[rank3]:               ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 2821, in wrap_fake_exception
[rank3]:     return fn()
[rank3]:            ^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3322, in <lambda>
[rank3]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3530, in run_node
[rank3]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3489, in run_node
[rank3]:     return node.target(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 1196, in redistribute_fn_with_prim_types
[rank3]:     return x.redistribute(*args_as_value, **kwargs_as_value)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_api.py", line 564, in redistribute
[rank3]:     return Redistribute.apply(
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 581, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_redistribute.py", line 321, in forward
[rank3]:     output = redistribute_local_tensor(
[rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
[rank3]:     new_local_tensor = partial_spec._reduce_value(
[rank3]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_ops/_embedding_ops.py", line 117, in _reduce_value
[rank3]:     assert self.mask_buffer.data is not None
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function TensorVariable.method_redistribute.<locals>.redistribute_fn_with_prim_types at 0x77cf4416a8e0>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:3', size=(4, 1024, 5120), dtype=torch.bfloat16), device_mesh=DeviceMesh((tp=4), device: 'cuda', stride: (1,)), placements=(_MaskPartial(offset_shape=(151936, 5120), offset_dim=0),)),), **{}): got AssertionError()

[rank3]: from user code:
[rank3]:    File "/opt/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 918, in wrapper
[rank3]:     output = func(self, *args, **kwargs)
[rank3]:   File "/opt/venv/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 480, in forward
[rank3]:     outputs: BaseModelOutputWithPast = self.model(
[rank3]:   File "/opt/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank3]:     outputs = func(self, *args, **kwargs)
[rank3]:   File "/opt/venv/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 371, in forward
[rank3]:     inputs_embeds = self.embed_tokens(input_ids)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1878, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1842, in inner
[rank3]:     hook_result = hook(self, args, result)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_api.py", line 983, in <lambda>
[rank3]:     lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
[rank3]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/parallel/style.py", line 285, in _prepare_output_fn
[rank3]:     outputs = outputs.redistribute(placements=output_layouts, async_op=True)

[rank3]: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

W0119 08:19:51.321000 2046 torch/distributed/elastic/multiprocessing/api.py:906] Sending process 2114 closing signal SIGTERM
W0119 08:19:51.322000 2046 torch/distributed/elastic/multiprocessing/api.py:906] Sending process 2115 closing signal SIGTERM
W0119 08:19:51.322000 2046 torch/distributed/elastic/multiprocessing/api.py:906] Sending process 2116 closing signal SIGTERM
E0119 08:19:51.569000 2046 torch/distributed/elastic/multiprocessing/api.py:880] failed (exitcode: 1) local_rank: 0 (pid: 2113) of binary: /opt/venv/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 936, in main
    run(args)
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 151, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/launcher/api.py", line 288, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/opt/Automodel/examples/llm_finetune/finetune.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-01-19_08:19:51
  host      : 7c814ccf8f4a
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 2113)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
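
The Dynamo message above suggests rerunning with TORCHDYNAMO_VERBOSE=1 and TORCH_LOGS="+dynamo" for more context. A minimal sketch of building that environment for a relaunch follows. The helper name `dynamo_debug_env` is hypothetical, and the launch command is left out because the original torchrun invocation is not shown in this report.

```python
import os


def dynamo_debug_env(base_env=None):
    """Return a copy of the environment with the Dynamo debug variables set."""
    env = dict(os.environ if base_env is None else base_env)
    # Both variables are named in the error message itself.
    env["TORCHDYNAMO_VERBOSE"] = "1"   # print the internal Dynamo stack trace
    env["TORCH_LOGS"] = "+dynamo"      # enable verbose Dynamo logging
    return env


env = dynamo_debug_env()
print({key: env[key] for key in ("TORCHDYNAMO_VERBOSE", "TORCH_LOGS")})
# Pass `env=env` to subprocess.run(...) together with the same launch command
# that produced the failure (not shown in this report).
```

Copying the environment rather than mutating `os.environ` keeps the debug settings scoped to the relaunched job.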
