@ziqi-wlb commented on Aug 19, 2025

What does this PR do?

This PR implements fully asynchronous RL training. The async-RL pipeline is built on a fully decoupled architecture: the actor-train, actor-forward-logp, ref-logp, and rollout-generate components run as separate stages that can be scheduled and scaled independently, which improves performance and scalability.
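
A minimal conceptual sketch of the decoupling, written with plain asyncio rather than the actual verl implementation (the stage names, the bounded queue hand-off, and the GENERATE_AHEAD_STEPS constant are illustrative assumptions, not this PR's API):

# Conceptual sketch only: decoupled stages connected by bounded queues.
# The bounded rollout queue caps how far generation may run ahead of training
# (analogous to generate_ahead_steps). Names are illustrative, not verl APIs.
import asyncio

GENERATE_AHEAD_STEPS = 2  # max staleness between generate and train

async def generate(rollout_q):
    step = 0
    while True:
        batch = f"rollout_batch_{step}"      # placeholder for sglang generation
        await rollout_q.put(batch)           # blocks once the queue is full
        step += 1

async def compute_logp(rollout_q, logp_q):
    while True:
        batch = await rollout_q.get()
        await logp_q.put((batch, "logp", "ref_logp"))  # placeholder forward passes

async def train(logp_q, total_steps=10):
    for _ in range(total_steps):
        batch = await logp_q.get()           # consume one batch, update the actor
        # ...optimizer step, then asynchronously push new params to rollout...

async def main():
    rollout_q = asyncio.Queue(maxsize=GENERATE_AHEAD_STEPS)
    logp_q = asyncio.Queue(maxsize=1)
    tasks = [asyncio.create_task(generate(rollout_q)),
             asyncio.create_task(compute_logp(rollout_q, logp_q))]
    await train(logp_q)
    for t in tasks:
        t.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)

asyncio.run(main())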

Async-RL workflow:
[figure: async-RL pipeline workflow diagram]

Async-RL achieves up to 50-100% training throughput improvement while maintaining convergence.
Benchmark Configuration:

  • Model: Red-MoE-16B
  • Hardware: 4 machines
  • Configuration: TP1 + PP1 + EP4 + SGLang-TP2
  • Algorithm: GRPO
  • Batch Size: 256
[figure: benchmark results]

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: [trainer, fsdp, vllm, recipe] feat: one step off async training recipe volcengine/verl#2231. PR #2231 implements only one-step-off asynchronous training, and its parameter synchronization via NCCL does not scale. This PR adds a state-machine mechanism for asynchronous parameter updates, which allows any component to be deployed separately and run asynchronously, e.g. separate parallel pipelines for actor-train, param-update, logp, and rollout.

  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)

    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots and evaluation results.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Async RL Configuration
+actor_rollout_ref.async_pipeline=True \

# Resource Management
+trainer.use_nodes_ratios=[0.5,0.5,0.5,0.5] \
# means: train/logp/ref_logp share 0.5 of the GPUs, generate uses the other 0.5

# Performance Tuning: enable asynchronous parameter update
+actor_rollout_ref.rollout.enable_dual_buffer=True \
# Sender-side bucket size (MB) on the actor training node during parameter update
+actor_rollout_ref.rollout.param_update_preduce_bucket_size_mb=512 \
# Receiver-side bucket size (MB) on the rollout inference node; too large a value can cause GPU OOM
+actor_rollout_ref.rollout.param_update_consume_bucket_size_mb=128 \

# Off-policy staleness: 2 means generation can run up to 2 steps ahead of training, i.e. one-step off-policy
+trainer.generate_ahead_steps=2 \

# Example launch script:
python3 -m verl.trainer.main_ppo --config-path=$ROOT_PATH/run_verl --config-name='redmoe_megatron' \
    ++hydra.run.dir=outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${env:RANK,0} \
    algorithm.adv_estimator=grpo \
    data.train_files="$TRAIN_DATA_PATH" \
    data.val_files="$TEST_DATA_PATH" \
    data.train_batch_size=128 \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=True \
    data.filter_overlong_prompts_workers=32 \
    data.truncation='error' \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.model.path=$MODEL_PATH \
    actor_rollout_ref.model.trust_remote_code=True \
    +actor_rollout_ref.model.use_fused_kernels=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.load_weight=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.megatron.param_offload=False \
    actor_rollout_ref.actor.megatron.grad_offload=True \
    actor_rollout_ref.actor.megatron.optimizer_offload=True \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1 \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \
    actor_rollout_ref.ref.megatron.param_offload=True \
    actor_rollout_ref.ref.megatron.grad_offload=True \
    actor_rollout_ref.ref.megatron.optimizer_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=1 \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.rollout.n=16 \
    +actor_rollout_ref.rollout.enable_dual_buffer=True \
    +actor_rollout_ref.rollout.param_update_preduce_bucket_size_mb=512 \
    +actor_rollout_ref.rollout.param_update_consume_bucket_size_mb=128 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    algorithm.use_kl_in_reward=False \
    +trainer.async_pipeline=True \
    +trainer.use_nodes_ratios=[0.5,0.5,0.5,0.5] \
    +trainer.generate_ahead_steps=2 \
    trainer.val_only=False \
    trainer.critic_warmup=0 \
    trainer.resume_mode=disable \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="verl_async_rl_redmoe16b" \
    trainer.experiment_name=$EXP_NAME \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=${WORLD_SIZE} \
    trainer.save_freq=50 \
    trainer.test_freq=5 \
    trainer.total_epochs=100
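
As a rough illustration of the resource split in this script (the interpretation of use_nodes_ratios below is an assumption based on the comment above, not taken from the PR's code): with trainer.nnodes=4 and trainer.n_gpus_per_node=8 there are 32 GPUs, and [0.5,0.5,0.5,0.5] colocates train/logp/ref_logp on one half while generate occupies the other half:

# Illustrative arithmetic only (assumed semantics of use_nodes_ratios; check the PR's code for the exact meaning).
ROLES = ["train", "logp", "ref_logp", "generate"]
use_nodes_ratios = [0.5, 0.5, 0.5, 0.5]
nnodes, n_gpus_per_node = 4, 8
total_gpus = nnodes * n_gpus_per_node  # 32

for role, ratio in zip(ROLES, use_nodes_ratios):
    print(f"{role:>9}: {int(total_gpus * ratio)} GPUs")
# With this setting, train/logp/ref_logp share one half of the cluster (16 GPUs)
# and generate runs on the other half, matching the
# "train/logp/ref_logp share 0.5, generate uses the other 0.5" comment above.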

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

  1. State-machine design for async-RL: RL training workflows are inherently complex. A synchronous approach can simply execute tasks sequentially, but async-RL requires explicit state transitions between tasks. To ensure both performance and correctness, the system uses flexible scheduling strategies that bind tasks to resources logically, and each task maintains its own produce/consume loop to prevent errors. Modeling the pipeline as a state machine keeps this manageable (a minimal sketch appears after this list).
    The state machine's transitions cover the entire async-RL pipeline workflow:
    dataloader → generate → rollout → logp → ref_logp → reward → train → param_update

  2. Asynchronous parameter synchronization (a sketch appears after this list):
    The parameter update process is decomposed into three stages:
    2.1. Gather: aggregates parameters via NCCL (must run serially)
    2.2. Send/Recv: asynchronous CPU communication
    2.3. Load: loads parameters on the rollout side without affecting GPU compute

  3. Add the Red-MoE model for GRPO.
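
A minimal sketch of the state-machine idea from item 1 (illustrative only; the state names follow the transition chain listed above, and the class and function names are hypothetical, not verl's API):

# Minimal state-machine sketch: a batch advances through the async-RL stages;
# illegal transitions are rejected. Illustrative, not this PR's implementation.
from enum import Enum, auto

class Stage(Enum):
    DATALOADER = auto()
    GENERATE = auto()
    ROLLOUT = auto()
    LOGP = auto()
    REF_LOGP = auto()
    REWARD = auto()
    TRAIN = auto()
    PARAM_UPDATE = auto()

# Legal next-stage table mirroring the pipeline; param_update loops back to generate.
TRANSITIONS = {
    Stage.DATALOADER: Stage.GENERATE,
    Stage.GENERATE: Stage.ROLLOUT,
    Stage.ROLLOUT: Stage.LOGP,
    Stage.LOGP: Stage.REF_LOGP,
    Stage.REF_LOGP: Stage.REWARD,
    Stage.REWARD: Stage.TRAIN,
    Stage.TRAIN: Stage.PARAM_UPDATE,
    Stage.PARAM_UPDATE: Stage.GENERATE,
}

class BatchStateMachine:
    def __init__(self):
        self.stage = Stage.DATALOADER

    def advance(self, next_stage: Stage) -> None:
        if TRANSITIONS[self.stage] is not next_stage:
            raise RuntimeError(f"illegal transition {self.stage} -> {next_stage}")
        self.stage = next_stage

sm = BatchStateMachine()
sm.advance(Stage.GENERATE)   # ok
sm.advance(Stage.ROLLOUT)    # ok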

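And a rough sketch of the three-stage parameter update from item 2, with a serial gather and send/recv plus load overlapped on background threads (the function bodies are placeholders and the bucket handling is an assumption, not the PR's implementation):

# Three-stage parameter-update sketch: serial gather, then asynchronous
# send/recv and load so GPU compute is not blocked. Bucket sizes mirror
# param_update_preduce/consume_bucket_size_mb. Illustrative only.
import queue
import threading

send_q: "queue.Queue[bytes]" = queue.Queue(maxsize=4)

def gather_buckets(bucket_size_mb=512):
    # Stage 1 (serial): gather full parameters via NCCL into fixed-size buckets.
    for _ in range(8):                       # placeholder for the real gather loop
        yield bytes(bucket_size_mb)          # placeholder bucket payload

def sender():
    # Stage 2 (async, CPU): ship buckets to the rollout nodes.
    for bucket in gather_buckets():
        send_q.put(bucket)
    send_q.put(None)                         # end-of-update marker

def loader(consume_bucket_size_mb=128):
    # Stage 3 (async): rollout side loads received weights without stalling compute.
    while (bucket := send_q.get()) is not None:
        pass                                 # placeholder: copy bucket into the inference engine

t_send = threading.Thread(target=sender)
t_load = threading.Thread(target=loader)
t_send.start(); t_load.start()
t_send.join(); t_load.join()
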
Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
