🐛 Bug Description
When I run the provided example script examples/ttrl/Qwen2.5-Math/math.sh on the latest TTRL setup, the job fails with:
Traceback (most recent call last):
File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 31, in main
run_ppo(config)
File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 65, in run_ppo
ray.get(runner.run.remote(config))
File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
return func(*args, **kwargs)
File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/worker.py", line 2961, in get
values, debugger_breakpoint = worker.get_objects(
File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/worker.py", line 1026, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ActorDiedError): ray::TaskRunner.run() (pid=106666, ip=144.214.210.20, actor_id=dbc11a20aca6cccf44a05b0301000000, repr=<main_ppo.TaskRunner object at 0x7ecd4488b6d0>)
File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 213, in run
trainer.init_workers()
File "/home/tjy/TTRL/verl/verl/trainer/ppo/ray_trainer.py", line 884, in init_workers
self.ref_policy_wg.init_model()
File "/home/tjy/TTRL/verl/verl/single_controller/ray/base.py", line 51, in __call__
output = ray.get(output)
ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::DkyUZOWorkerDict_0:1:WorkerDict.__init__() (pid=107888, ip=144.214.210.20, actor_id=ee8a699f30b92b905b911c6501000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f469ceace50>)
File "/home/tjy/TTRL/verl/verl/single_controller/ray/base.py", line 789, in __init__
self.worker_dict[key] = user_defined_cls(
File "/home/tjy/TTRL/verl/verl/workers/fsdp_workers.py", line 176, in __init__
assert self.config.actor.ppo_mini_batch_size > 0, (
AssertionError: ppo_mini_batch_size 0 should be larger than 0 after normalization
📜 Analysis
This seems to come from the normalization step in
verl/workers/fsdp_workers.py:
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
assert self.config.actor.ppo_mini_batch_size > 0, (
f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after "
f"normalization"
)

In my environment:
- default self.config.rollout.n = 1 in verl/trainer/config/ppo_trainer.yaml
- default self.ulysses_sequence_parallel_size = 1 in verl/trainer/config/ppo_trainer.yaml
- default ppo_mini_batch_size = 1 in examples/ttrl/Qwen2.5-Math/math.sh
- self.device_mesh.size() equals the number of GPUs on the single machine (I'm running single-node multi-GPU; I tried both 2 GPUs and 3 GPUs, and both trigger the same error)
So the calculation becomes:
1 * 1 // (2 // 1) = 0   # or, with 3 GPUs: 1 * 1 // (3 // 1) = 0
which directly triggers the assertion.
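For clarity, here is a minimal Python sketch of that arithmetic using the values from my environment. The variable names only mirror the ones in fsdp_workers.py; this is illustrative arithmetic, not the actual verl code:

```python
# Illustrative reproduction of the normalization in verl/workers/fsdp_workers.py,
# using the default values described above (not the actual verl code).
rollout_n = 1            # self.config.rollout.n (ppo_trainer.yaml default)
ulysses_sp = 1           # self.ulysses_sequence_parallel_size (default)
ppo_mini_batch_size = 1  # MINI_BATCH_SIZE in examples/ttrl/Qwen2.5-Math/math.sh

for world_size in (2, 3):  # self.device_mesh.size() on my 2-GPU / 3-GPU runs
    normalized = (ppo_mini_batch_size * rollout_n) // (world_size // ulysses_sp)
    print(world_size, normalized)  # prints "2 0" and "3 0" -> assertion fails in both cases
```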
💡 Expected Behavior
The example script examples/ttrl/Qwen2.5-Math/math.sh should run successfully with default settings.
🧩 Environment
- TTRL commit: 26413fa664c4cf1ef622ebdb265740645b3c7831 (from git -C /home/tjy/TTRL rev-parse HEAD)
- verl version: v0.4.1.dev0
- Python version: Python 3.10.0
- OS: Linux cs658a 5.15.0-136-generic
- CUDA / GPUs: single node with 3 GPUs (tested both 2 and 3)
nvidia-smi
Sat Oct 25 23:07:46 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.07 Driver Version: 570.133.07 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX 5880 Ada Gene... Off | 00000000:27:00.0 Off | Off |
| 30% 30C P0 60W / 285W | 0MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX 5880 Ada Gene... Off | 00000000:98:00.0 Off | Off |
| 30% 30C P0 54W / 285W | 0MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX 5880 Ada Gene... Off | 00000000:B8:00.0 Off | Off |
| 30% 29C P0 55W / 285W | 0MiB / 49140MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
- Command run: bash examples/ttrl/Qwen2.5-Math/math.sh
🔁 Reproduction Steps
- Clone the repo.
- Install dependencies.
- Copy this math.sh, but use your own DATA_LOCAL_DIR and BACKBONE_PATH (the only change from the example is the GPU count, from 8 to 2):
#!/bin/bash
#export VLLM_ATTENTION_BACKEND=XFORMERS
unset VLLM_ATTENTION_BACKEND
export VLLM_USE_V1=1
# ------------------------------------------------------------
DATE=$(date +%m%d)
TIME_TAG=$(date +%H%M%S)
TASK="MATH-TTT"
BACKBONE="Qwen2.5-Math-1.5B"
ADVANTAGE="grpo"
K=3
MAX_PROMPT_LENGTH=1024
MAX_RESPONSE_LENGTH=$((1024 * $K))
if [ "$K" -gt 8 ]; then
N=4
else
N=16
fi
EPISODE=10
DATA_TRAIN_BATCH_SIZE=32
N_VOTES_PER_PROMPT=64
N_SAMPLES_PER_PROMPT=32
MINI_BATCH_SIZE=1
MICRO_BATCH_SIZE=2
DATA_LOCAL_DIR="/home/tjy/TTRL/verl/data"
BACKBONE_PATH="/home/tjy/TTRL/verl/llms/${BACKBONE}"
MODEL="${TASK}-${BACKBONE}"
EXPERIMENT="TTRL-Len@${K}k"
WANDB_PROJECT="TTRL-verl"
LOG_NAME="${DATE}-${EXPERIMENT}-${MODEL}-${ADVANTAGE}"
OUTPUT_DIR="checkpoints/${WANDB_PROJECT}/${MODEL}/${DATE}/${EXPERIMENT}-${ADVANTAGE}-${TIME_TAG}"
# ------------------------------------------------------------
python -m verl.trainer.main_ppo \
--config-name='ppo_trainer_ttrl.yaml'\
data.train_files=["$DATA_LOCAL_DIR/$TASK/train.parquet"] \
data.val_files=["$DATA_LOCAL_DIR/$TASK/test.parquet"] \
data.max_prompt_length=$MAX_PROMPT_LENGTH \
data.max_response_length=$MAX_RESPONSE_LENGTH \
data.train_batch_size=$DATA_TRAIN_BATCH_SIZE \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$BACKBONE_PATH \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$MINI_BATCH_SIZE \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.optim.lr=5e-7 \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03 \
actor_rollout_ref.actor.optim.warmup_style='cosine' \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=$N \
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
critic.optim.lr=9e-6 \
critic.model.use_remove_padding=True \
critic.model.path=$BACKBONE_PATH \
critic.model.enable_gradient_checkpointing=True \
critic.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.kl_ctrl.kl_coef=0.00 \
algorithm.adv_estimator=$ADVANTAGE \
custom_reward_function.path="./verl/utils/reward_score/ttrl_math/__init__.py" \
custom_reward_function.name=reward_func \
ttrl.enable=True \
ttrl.n_votes_per_prompt=$N_VOTES_PER_PROMPT \
ttrl.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
trainer.logger=['console','wandb'] \
trainer.project_name=$WANDB_PROJECT \
trainer.experiment_name=$LOG_NAME \
trainer.n_gpus_per_node=2 \
trainer.nnodes=1 \
trainer.save_freq=2000000 \
trainer.test_freq=2 \
trainer.max_actor_ckpt_to_keep=0 \
trainer.max_critic_ckpt_to_keep=0 \
trainer.default_local_dir=$OUTPUT_DIR \
trainer.total_epochs=$EPISODE "$@"
echo "Output directory: $OUTPUT_DIR"
- Run bash examples/ttrl/Qwen2.5-Math/math.sh
- Observe the error above.
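As a side note (not intended as the proper fix, just a sanity check derived from the normalization formula quoted above): a small sketch of which ppo_mini_batch_size values would pass the assertion on my 2-GPU and 3-GPU setups, keeping rollout.n = 1 and ulysses_sequence_parallel_size = 1:

```python
# Sanity check only: which ppo_mini_batch_size values survive the normalization
# from fsdp_workers.py, assuming rollout.n = 1 and
# ulysses_sequence_parallel_size = 1 (the defaults). Not a proposed fix.
rollout_n = 1
ulysses_sp = 1

for world_size in (2, 3):
    ok = [mbs for mbs in range(1, 9)
          if (mbs * rollout_n) // (world_size // ulysses_sp) > 0]
    print(f"{world_size} GPUs: smallest passing ppo_mini_batch_size = {ok[0]}")
    # 2 GPUs -> 2, 3 GPUs -> 3, i.e. the example's default of 1 can never pass here
```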
📎 Additional Context
I wanted to reproduce the example exactly as provided (without changing config values such as rollout.n or ppo_mini_batch_size), so this issue might affect others trying to run the example out of the box.
🙏 Thank you for the amazing work on TTRL and verl!
Looking forward to your advice or a config fix.