Describe the bug
For Qwen3-32B, FP8 training itself runs successfully. However, when training reaches max_steps, the checkpoint fails to save, unlike with BF16 precision, where the checkpoint is saved without issue.
Steps/Code to reproduce bug
Docker image: nvcr.io/nvidia/nemo-automodel:25.11
Hardware: 4x H200 GPUs.
YAML file:
step_scheduler:
  global_batch_size: 32
  local_batch_size: 4 #2
  ckpt_every_steps: 100
  val_every_steps: 100  # will run every x number of gradient steps
  num_epochs: 1

dist_env:
  backend: nccl
  timeout_minutes: 1

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3-32B

checkpoint:
  enabled: true
  checkpoint_dir: ./checkpoints/Qwen3_32B/
  model_save_format: safetensors
  save_consolidated: true

# torch.compile configuration
compile:
  enabled: false
  mode: "default"  # Options: "default", "reduce-overhead", "max-autotune"
  fullgraph: false
  dynamic: false  # Set to false for better performance with fixed shapes
  backend: null  # Use default backend (inductor)

fp8:
  enabled: true
  recipe_name: tensorwise  # Options: tensorwise, rowwise, rowwise_with_gw_hp
  enable_fsdp_float8_all_gather: true
  precompute_float8_dynamic_scale_for_fsdp: true
  force_recompute_fp8_weight_in_bwd: true
  #filter_fqns: ["lm_head"]
  emulate: false

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  #_target_: nemo_automodel.components.distributed.megatron_fsdp.MegatronFSDPManager
  dp_size: none
  dp_replicate_size: 1  # dp_shard_size = dp_size / dp_replicate_size and dp_shard_size < dp_size. For DDP usecase, use DDPManager
  tp_size: 4
  pp_size: 1  # No PP size in megatron_fsdp.py # https://github.com/NVIDIA-NeMo/Automodel/blob/main/nemo_automodel/components/distributed/megatron_fsdp.py
  cp_size: 1
  sequence_parallel: false

loss_fn:
  #_target_: nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy  # TP will have errors
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

packed_sequence:
  packed_sequence_size: 1024

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn:
    _target_: nemo_automodel.components.datasets.utils.default_collater
    pad_seq_len_divisible: 16  # fp8 requires divisible by 16
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: validation
  limit_dataset_samples: 64

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  #batch_size: 1

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  fused: True

lr_scheduler:
  lr_decay_style: cosine
  min_lr: 1.0e-6

# Uncomment and configure for W&B logging
# wandb:
#   project: <your_wandb_project>
#   entity: <your_wandb_entity>
#   name: <your_wandb_exp_name>
#   save_dir: <your_wandb_save_dir>
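
For context, my understanding (an assumption, not taken from the Automodel source) is that the fp8 section above maps onto torchao's float8 training conversion, roughly as in the sketch below; the recipe name, conversion call, and FSDP scale precompute are torchao APIs, and the module filter corresponds to the commented-out filter_fqns:

```python
# Rough sketch (assumption, not the actual Automodel implementation) of what the
# fp8 section corresponds to in torchao terms: nn.Linear layers are swapped for
# float8 training layers according to the chosen scaling recipe.
import torch
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)


def apply_fp8_training(model: torch.nn.Module) -> torch.nn.Module:
    # recipe_name: tensorwise -> the torchao "tensorwise" scaling recipe
    config = Float8LinearConfig.from_recipe_name("tensorwise")

    # filter_fqns (commented out in the YAML) would become a module filter that
    # excludes layers such as lm_head from the float8 conversion.
    def module_filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        return "lm_head" not in fqn

    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
    return model


# With enable_fsdp_float8_all_gather, the float8 FSDP recipe also calls
# precompute_float8_dynamic_scale_for_fsdp(model) after each optimizer step.
```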
After the validation stage, checkpoint saving triggers an error like the following:
2026-01-19 10:01:54 | INFO | root | [val] name "default" | step 99 | epoch 0 | loss 0.0779 | lr 9.88e-06 | num_label_tokens 836
/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
Saving checkpoint to ./checkpoints/Qwen3_32B/epoch_0_step_99
/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
/opt/venv/lib/python3.12/site-packages/safetensors/torch.py:18: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return tensor.storage().data_ptr()
/opt/venv/lib/python3.12/site-packages/safetensors/torch.py:18: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return tensor.storage().data_ptr()
/opt/venv/lib/python3.12/site-packages/safetensors/torch.py:18: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return tensor.storage().data_ptr()
/opt/venv/lib/python3.12/site-packages/safetensors/torch.py:18: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
return tensor.storage().data_ptr()
[rank1]: Traceback (most recent call last):
[rank1]: File "/opt/Automodel/examples/llm_finetune/finetune.py", line 33, in <module>
[rank1]: main()
[rank1]: File "/opt/Automodel/examples/llm_finetune/finetune.py", line 29, in main
[rank1]: recipe.run_train_validation_loop()
[rank1]: File "/opt/Automodel/nemo_automodel/recipes/llm/train_ft.py", line 1146, in run_train_validation_loop
[rank1]: self.save_checkpoint(
[rank1]: File "/opt/Automodel/nemo_automodel/recipes/base_recipe.py", line 269, in save_checkpoint
[rank1]: self.checkpointer.save_model(model, path, peft_config=self.peft_config, tokenizer=tokenizer)
[rank1]: File "/opt/Automodel/nemo_automodel/components/checkpoint/checkpointing.py", line 223, in save_model
[rank1]: self._model_ctx.future = self._do_save(state_dict, model_dir, storage_writer)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/Automodel/nemo_automodel/components/checkpoint/checkpointing.py", line 518, in _do_save
[rank1]: dcp.save(state_dict, checkpoint_id=path, storage_writer=storage_writer, planner=planner)
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
[rank1]: result = func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 475, in inner_func
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_saver.py", line 187, in save
[rank1]: return _save_state_dict(
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_saver.py", line 441, in _save_state_dict
[rank1]: return distW.all_reduce("write", write_data, finish_checkpoint)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 259, in all_reduce
[rank1]: raise final_result
[rank1]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3])
[rank1]: Traceback (most recent call last): (RANK 0)
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 239, in all_reduce
[rank1]: local_data = map_fun()
[rank1]: ^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
[rank1]: result = func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_saver.py", line 430, in write_data
[rank1]: all_writes = storage_writer.write_data(final_local_plan, planner)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/Automodel/nemo_automodel/components/checkpoint/_backports/hf_storage.py", line 167, in write_data
[rank1]: return super()._write_data(planner, file_queue)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/Automodel/nemo_automodel/components/checkpoint/_backports/filesystem.py", line 702, in _write_data
[rank1]: _write_files_from_queue(
[rank1]: File "/opt/Automodel/nemo_automodel/components/checkpoint/_backports/filesystem.py", line 449, in _write_files_from_queue
[rank1]: save(
[rank1]: File "/opt/venv/lib/python3.12/site-packages/safetensors/torch.py", line 316, in save
[rank1]: serialized = serialize(_flatten(tensors), metadata=metadata)
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/venv/lib/python3.12/site-packages/safetensors/torch.py", line 570, in _flatten
[rank1]: shared_pointers = _find_shared_tensors(tensors)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/venv/lib/python3.12/site-packages/safetensors/torch.py", line 77, in _find_shared_tensors
[rank1]: and storage_ptr(v) != 0
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/opt/venv/lib/python3.12/site-packages/safetensors/torch.py", line 18, in storage_ptr
[rank1]: return tensor.storage().data_ptr()
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/storage.py", line 1247, in data_ptr
[rank1]: return self._data_ptr()
[rank1]: ^^^^^^^^^^^^^^^^
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/storage.py", line 1251, in _data_ptr
[rank1]: return self._untyped_storage.data_ptr()
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Attempted to access the data pointer on an invalid python storage.
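
From the traceback, the save fails inside safetensors' _find_shared_tensors when tensor.storage().data_ptr() is called on parameters that, I assume, are still FP8/tensor-subclass wrappers and therefore do not expose a valid storage. As a hypothetical workaround sketch (to_plain_state_dict is my own helper, not part of Automodel, and it may not cover sharded DTensor parameters), the state dict could be materialized into plain bf16 tensors before the safetensors write:

```python
# Hypothetical workaround sketch, not Automodel code: materialize every parameter
# as a plain, contiguous bf16 CPU tensor so safetensors can read a real storage pointer.
import torch


def to_plain_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    plain = {}
    for name, tensor in state_dict.items():
        # .to(torch.bfloat16) + .contiguous() should drop any float8/subclass wrapper
        # and yield a dense tensor with an accessible storage (assumption).
        plain[name] = tensor.detach().to(torch.bfloat16).contiguous().cpu()
    return plain


# Usage sketch: convert before handing the dict to safetensors.
# from safetensors.torch import save_file
# save_file(to_plain_state_dict(model.state_dict()), "model.safetensors")
```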