AssertionError: ppo_mini_batch_size 0 should be larger than 0 after normalization when running examples/ttrl/Qwen2.5-Math/math.sh #43

@childofcuriosity

Description

🐛 Bug Description

When I run the provided example script examples/ttrl/Qwen2.5-Math/math.sh on the latest TTRL setup, the job fails with:

Traceback (most recent call last):
  File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 31, in main
    run_ppo(config)
  File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 65, in run_ppo
    ray.get(runner.run.remote(config))
  File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/worker.py", line 2961, in get
    values, debugger_breakpoint = worker.get_objects(
  File "/home/tjy/miniconda3/envs/ttrl/lib/python3.10/site-packages/ray/_private/worker.py", line 1026, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ActorDiedError): ray::TaskRunner.run() (pid=106666, ip=144.214.210.20, actor_id=dbc11a20aca6cccf44a05b0301000000, repr=<main_ppo.TaskRunner object at 0x7ecd4488b6d0>)
  File "/home/tjy/TTRL/verl/verl/trainer/main_ppo.py", line 213, in run
    trainer.init_workers()
  File "/home/tjy/TTRL/verl/verl/trainer/ppo/ray_trainer.py", line 884, in init_workers
    self.ref_policy_wg.init_model()
  File "/home/tjy/TTRL/verl/verl/single_controller/ray/base.py", line 51, in __call__
    output = ray.get(output)
ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::DkyUZOWorkerDict_0:1:WorkerDict.__init__() (pid=107888, ip=144.214.210.20, actor_id=ee8a699f30b92b905b911c6501000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f469ceace50>)
  File "/home/tjy/TTRL/verl/verl/single_controller/ray/base.py", line 789, in __init__
    self.worker_dict[key] = user_defined_cls(
  File "/home/tjy/TTRL/verl/verl/workers/fsdp_workers.py", line 176, in __init__
    assert self.config.actor.ppo_mini_batch_size > 0, (
AssertionError: ppo_mini_batch_size 0 should be larger than 0 after normalization

📜 Analysis

This seems to come from the normalization step in verl/workers/fsdp_workers.py (comments mine):

# scale the mini batch size by the number of rollout samples per prompt,
# then divide by the data-parallel size (device mesh size // sequence-parallel size)
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
assert self.config.actor.ppo_mini_batch_size > 0, (
    f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after "
    f"normalization"
)

In my environment:

  • default self.config.rollout.n = 1 in verl/trainer/config/ppo_trainer.yaml
  • default self.ulysses_sequence_parallel_size = 1 in verl/trainer/config/ppo_trainer.yaml
  • default ppo_mini_batch_size = 1 in examples/ttrl/Qwen2.5-Math/math.sh
  • self.device_mesh.size() equals the number of GPUs on the single machine
    (I’m running single-node multi-GPU and have tried both 2 and 3 GPUs; both trigger the same error)

So the calculation becomes:

1 * 1 // (2 // 1) = 0    # or, with 3 GPUs: 1 * 1 // (3 // 1) = 0

which directly triggers the assertion.
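
For concreteness, here is that arithmetic replayed as a standalone Python snippet (the variable names are mine, not verl’s):

rollout_n = 1            # actor_rollout_ref.rollout.n, yaml default
ulysses_sp_size = 1      # ulysses_sequence_parallel_size, yaml default
device_mesh_size = 2     # GPUs on my single node (I also tried 3)
ppo_mini_batch_size = 1  # MINI_BATCH_SIZE in math.sh

ppo_mini_batch_size *= rollout_n                             # still 1
ppo_mini_batch_size //= device_mesh_size // ulysses_sp_size  # 1 // 2 == 0
assert ppo_mini_batch_size > 0  # raises, exactly as in the traceback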


💡 Expected Behavior

The example script examples/ttrl/Qwen2.5-Math/math.sh should run successfully with default settings.


🧩 Environment

  • TTRL commit: 26413fa664c4cf1ef622ebdb265740645b3c7831 (from git -C /home/tjy/TTRL rev-parse HEAD)
  • verl version: v0.4.1.dev0
  • Python version: Python 3.10.0
  • OS: Linux cs658a 5.15.0-136-generic
  • CUDA / GPUs: single node with 3 GPUs (tested both 2 and 3)
 nvidia-smi
Sat Oct 25 23:07:46 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.07             Driver Version: 570.133.07     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 5880 Ada Gene...    Off |   00000000:27:00.0 Off |                  Off |
| 30%   30C    P0             60W /  285W |       0MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX 5880 Ada Gene...    Off |   00000000:98:00.0 Off |                  Off |
| 30%   30C    P0             54W /  285W |       0MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX 5880 Ada Gene...    Off |   00000000:B8:00.0 Off |                  Off |
| 30%   29C    P0             55W /  285W |       0MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
  • Command run:

    bash examples/ttrl/Qwen2.5-Math/math.sh

🔁 Reproduction Steps

  1. Clone the repo

  2. Install dependencies

  3. Copy the math.sh below, using your own DATA_LOCAL_DIR and BACKBONE_PATH (the only other change is reducing the GPU count from 8 to 2)

#!/bin/bash
#export VLLM_ATTENTION_BACKEND=XFORMERS
unset VLLM_ATTENTION_BACKEND
export VLLM_USE_V1=1

# ------------------------------------------------------------

DATE=$(date +%m%d)
TIME_TAG=$(date +%H%M%S)

TASK="MATH-TTT"
BACKBONE="Qwen2.5-Math-1.5B"
ADVANTAGE="grpo"

K=3
MAX_PROMPT_LENGTH=1024
MAX_RESPONSE_LENGTH=$((1024 * $K))
if [ "$K" -gt 8 ]; then
  N=4
else
  N=16
fi

EPISODE=10
DATA_TRAIN_BATCH_SIZE=32
N_VOTES_PER_PROMPT=64
N_SAMPLES_PER_PROMPT=32
MINI_BATCH_SIZE=1
MICRO_BATCH_SIZE=2

DATA_LOCAL_DIR="/home/tjy/TTRL/verl/data"
BACKBONE_PATH="/home/tjy/TTRL/verl/llms/${BACKBONE}"

MODEL="${TASK}-${BACKBONE}"
EXPERIMENT="TTRL-Len@${K}k"

WANDB_PROJECT="TTRL-verl"
LOG_NAME="${DATE}-${EXPERIMENT}-${MODEL}-${ADVANTAGE}"
OUTPUT_DIR="checkpoints/${WANDB_PROJECT}/${MODEL}/${DATE}/${EXPERIMENT}-${ADVANTAGE}-${TIME_TAG}"

# ------------------------------------------------------------
python -m verl.trainer.main_ppo \
  --config-name='ppo_trainer_ttrl.yaml' \
  data.train_files=["$DATA_LOCAL_DIR/$TASK/train.parquet"] \
  data.val_files=["$DATA_LOCAL_DIR/$TASK/test.parquet"] \
  data.max_prompt_length=$MAX_PROMPT_LENGTH \
  data.max_response_length=$MAX_RESPONSE_LENGTH \
  data.train_batch_size=$DATA_TRAIN_BATCH_SIZE \
  data.filter_overlong_prompts=True \
  data.truncation='error' \
  actor_rollout_ref.model.path=$BACKBONE_PATH \
  actor_rollout_ref.model.enable_gradient_checkpointing=True \
  actor_rollout_ref.model.use_remove_padding=True \
  actor_rollout_ref.actor.ppo_mini_batch_size=$MINI_BATCH_SIZE \
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
  actor_rollout_ref.actor.use_kl_loss=True \
  actor_rollout_ref.actor.optim.lr=5e-7 \
  actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03 \
  actor_rollout_ref.actor.optim.warmup_style='cosine' \
  actor_rollout_ref.actor.fsdp_config.param_offload=False \
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
  actor_rollout_ref.ref.fsdp_config.param_offload=True \
  actor_rollout_ref.rollout.name=vllm \
  actor_rollout_ref.rollout.temperature=1.0 \
  actor_rollout_ref.rollout.enforce_eager=False \
  actor_rollout_ref.rollout.free_cache_engine=False \
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
  actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
  actor_rollout_ref.rollout.val_kwargs.do_sample=True \
  actor_rollout_ref.rollout.val_kwargs.n=$N \
  actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
  actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
  actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
  actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) \
  critic.optim.lr=9e-6 \
  critic.model.use_remove_padding=True \
  critic.model.path=$BACKBONE_PATH \
  critic.model.enable_gradient_checkpointing=True \
  critic.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \
  critic.model.fsdp_config.param_offload=False \
  critic.model.fsdp_config.optimizer_offload=False \
  algorithm.kl_ctrl.kl_coef=0.00 \
  algorithm.adv_estimator=$ADVANTAGE \
  custom_reward_function.path="./verl/utils/reward_score/ttrl_math/__init__.py" \
  custom_reward_function.name=reward_func \
  ttrl.enable=True \
  ttrl.n_votes_per_prompt=$N_VOTES_PER_PROMPT \
  ttrl.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
  trainer.logger=['console','wandb'] \
  trainer.project_name=$WANDB_PROJECT \
  trainer.experiment_name=$LOG_NAME \
  trainer.n_gpus_per_node=2 \
  trainer.nnodes=1 \
  trainer.save_freq=2000000 \
  trainer.test_freq=2 \
  trainer.max_actor_ckpt_to_keep=0 \
  trainer.max_critic_ckpt_to_keep=0 \
  trainer.default_local_dir=$OUTPUT_DIR \
  trainer.total_epochs=$EPISODE "$@"

echo "Output directory: $OUTPUT_DIR"

  4. Run

    bash examples/ttrl/Qwen2.5-Math/math.sh

  5. Observe the error above.


📎 Additional Context

I wanted to reproduce the example exactly as provided (without changing config values like rollout.n or ppo_mini_batch_size),
so this issue likely affects anyone trying to run the example out of the box.
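
Based purely on the arithmetic above, an untested guess at a local workaround (possibly not the intended TTRL hyperparameters) is to raise MINI_BATCH_SIZE in math.sh so it survives the integer division:

# hypothetical change to math.sh, not a confirmed fix: the normalized value is
# MINI_BATCH_SIZE * rollout.n // (n_gpus // ulysses_sequence_parallel_size),
# so with rollout.n = 1 it must be at least the per-node GPU count
MINI_BATCH_SIZE=2   # for trainer.n_gpus_per_node=2; use 3 when running on 3 GPUs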


🙏 Thank you for the amazing work on TTRL and verl!
Looking forward to your advice or a config fix.
