Rollout Correction
Author: Yingru Li
Last updated: 10/30/2025.
📖 Documentation Structure
This document - Practical usage guide: configurations, presets, troubleshooting
Mathematical Formulations - Theoretical foundations, derivations, and algorithmic details
Start here for implementation, refer to the math doc for theory and design rationale.
This document provides a comprehensive overview of the Rollout Correction implementation in verl.
Note on Naming: This feature is called “Rollout Correction” to reflect the complete functionality: importance sampling (IS) weights and rejection sampling (RS). The internal variable rollout_is_weights retains its name as it specifically refers to the IS weights component.
BibTeX Citation
@online{liu-li-2025-rl-collapse,
title = {When Speed Kills Stability: Demystifying {RL} Collapse from the Training-Inference Mismatch},
author = {Liu, Jiacai and Li, Yingru and Fu, Yuqian and Wang, Jiawei and Liu, Qian and Shen, Yu},
year = {2025},
month = sep,
url = {https://richardli.xyz/rl-collapse}
}
@article{li2025trust,
title={Trust Region Masking for Long-Horizon LLM Reinforcement Learning},
author={Li, Yingru and Liu, Jiacai and Xu, Jiawei and Tong, Yuxuan and Li, Ziniu and Liu, Qian and Wang, Baoxiang},
journal={arXiv preprint arXiv:2512.23075},
year={2025}
}
Blog Series
Main blog post: https://richardli.xyz/rl-collapse
Part 1: Why Mismatch Breaks LLM-RL (analytical framework using TV distance for bias and χ²-divergence for variance)
Part 2: The Gradient Estimator Trials (token-level vs sequence-level correction bias-variance tradeoff)
Part 3: When Math Meets Reality—Toxic Tails and Length Traps (why rejection over clipping, and geometric-level RS)
Latest Paper: https://arxiv.org/abs/2512.23075
Overview
Rollout Correction provides a unified framework to handle general off-policy problems in RL training. Any scenario where the data collection distribution differs from the training distribution can benefit from these methods.
Common off-policy scenarios:
Policy Mismatch (Implementation Differences)
Different precision: FP8 vs FP16 vs BF16 vs FP32
Different backends: vLLM vs SGLang vs FSDP vs Megatron
Different implementations even with identical weights
Temporal Lag (Model Staleness)
Rollout uses older checkpoint while training has progressed
Asynchronous rollout workers with stale parameters
Common in distributed/async RL systems
Replay Buffers
Training on historical trajectories from earlier iterations
Experience replay from different policy versions
Data augmentation or resampling strategies
Off-Policy Algorithms
Behavioral cloning from expert demonstrations
DAPO (data from auxiliary policies)
Any algorithm using trajectories from a different policy
Data Quality Filtering
Reweighting or filtering collected data
Preference learning with modified distributions
Curriculum learning with distribution shifts
These off-policy gaps can cause training instability and policy collapse. Rollout Correction uses importance sampling (IS) weights and rejection sampling (RS) to correct for any distribution shift between data collection and training.
Important Note on Common Implementation Mistakes:
Many LLM-RL implementations incorrectly apply PPO by ignoring the actual rollout policy π_rollout and assuming the training reference policy π_old is the behavior policy. This is mathematically incorrect when π_rollout ≠ π_old (which is typical in LLM-RL due to precision/backend differences between rollout and training).
This is not PPO’s fault - PPO itself is mathematically correct. The issue is the incorrect assumption that π_old = π_rollout in naive implementations.
This critical implementation mistake that leads to RL training collapse was identified in the blog post “When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch” and motivated the development of this rollout correction framework.
Mathematically correct approaches:
Decoupled mode: Three policies (π_rollout, π_old, π_θ) with IS correction from π_rollout to π_old
Bypass mode: Two policies (π_rollout = π_old, π_θ) using actual rollout policy as PPO anchor
Bypass + Policy Gradient mode: Two policies (π_rollout, π_θ) with IS/RS correction and no PPO clipping
See Mathematical Formulations for detailed explanation.
Key Design Principle: Separation of IS Weights and Rejection Sampling
The implementation cleanly separates two orthogonal mechanisms:
IS Weights (
rollout_is_weights): Continuous reweighting for gradient correctionPolicy ratio: π_old/π_rollout (decoupled) or π_θ/π_rollout (bypass)
Safety-bounded: Clamped to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow
Token level: Bounds per-token ratios
Sequence level: Bounds product of ratios (broadcast to all tokens)
Truncated: Upper clamped via
.clamp(max=rollout_is_threshold)(TIS: Truncated Importance Sampling)Zeroed at padding: Multiplied by response_mask to zero out padding positions
Used to weight policy gradients (variance reduction)
Rejection Sampling (
modified_response_mask): Binary filtering for outlier exclusionCreates binary mask: 1 = keep, 0 = reject
Rejects tokens/sequences with IS ratios outside [lower_threshold, upper_threshold]
Modifies response_mask to exclude rejected samples from training
This separation ensures:
✅ IS weights provide continuous reweighting (reduce variance)
✅ Rejection sampling provides hard filtering (remove extreme outliers)
✅ Both mechanisms can be enabled independently or together
✅ Safety bounds prevent numerical overflow in all cases
Quick Start: Using Verified Presets
NEW: We now provide typed configuration with verified presets for common scenarios. These presets have been validated with tens of thousands of GPU hours across various models and training scenarios.
Python API
from verl.trainer.config.algorithm import RolloutCorrectionConfig
# === Decoupled PPO mode (3 policies: π_rollout, π_old, π_θ) ===
# IS weights correct for gap between π_old and π_rollout
config = RolloutCorrectionConfig.decoupled_token_is() # Token-TIS
config = RolloutCorrectionConfig.decoupled_seq_is() # Seq-TIS
config = RolloutCorrectionConfig.decoupled_seq_is_rs() # Seq-MIS
config = RolloutCorrectionConfig.decoupled_geo_rs() # Geo-RS (ratio mode)
config = RolloutCorrectionConfig.decoupled_geo_rs_token_tis() # Geo-RS + Token-TIS
# === K3 KL Estimator presets (more stable for small KL) ===
config = RolloutCorrectionConfig.decoupled_k3_rs() # K3-RS only
config = RolloutCorrectionConfig.decoupled_k3_rs_token_tis() # K3-RS + Token-TIS
# === Bypass PPO mode (2 policies: π_rollout = π_old, π_θ) - fast ===
# PPO ratio handles IS, so no explicit IS weights needed
config = RolloutCorrectionConfig.bypass_ppo_clip() # PPO-clip only
config = RolloutCorrectionConfig.bypass_ppo_clip_geo_rs() # PPO-clip + Geo-RS
config = RolloutCorrectionConfig.bypass_ppo_clip_k3_rs() # PPO-clip + K3-RS
# === Bypass PG mode (2 policies, no PPO clipping) - fast ===
# IS weights computed on-the-fly as π_θ / π_rollout
config = RolloutCorrectionConfig.bypass_pg_is() # Seq-TIS + PG
config = RolloutCorrectionConfig.bypass_pg_geo_rs() # Geo-RS + PG
config = RolloutCorrectionConfig.bypass_pg_geo_rs_token_tis() # Geo-RS + Token-TIS + PG
# === Other ===
config = RolloutCorrectionConfig.disabled() # Metrics only (no correction)
YAML Configuration (Advanced)
For advanced customization or YAML-based configs:
algorithm:
rollout_correction:
rollout_is: token # IS weights: "token", "sequence", or null
rollout_is_threshold: 2.0 # TIS upper bound, or "0.5_5.0" for IcePop
rollout_is_batch_normalize: false # Batch normalize IS weights to mean=1.0
rollout_rs: null # Rejection sampling: comma-separated canonical options (e.g. "token_k1,seq_max_k2")
rollout_rs_threshold: null # Threshold spec: float(s) or "lower_upper" string(s)
bypass_mode: false # Skip old_log_prob computation (sets π_old = π_rollout)
loss_type: ppo_clip # Loss type in bypass mode: "ppo_clip" (default) or "reinforce"
# REQUIRED: Enable log prob calculation
actor_rollout_ref:
rollout:
calculate_log_probs: true
Files
Core Implementation
verl/trainer/ppo/rollout_corr_helper.py- Containscompute_rollout_correction_and_rejection_mask()andcompute_offpolicy_metrics()verl/trainer/ppo/core_algos.py- Rollout Correction integration with PPO and REINFORCE modes (compute_policy_loss_bypass_mode(),compute_policy_loss_reinforce())verl/trainer/ppo/ray_trainer.py- Bypass mode implementation (skipsold_log_probcomputation)verl/workers/utils/losses.py-ppo_lossloss function wired to actorTrainingWorkerviaverl.workers.engine_workers.ActorRolloutRefWorker.init_modelverl/trainer/ppo/core_algos.py-@register_policy_loss("bypass_mode")policy loss that invokescompute_rollout_correction_and_rejection_maskand emits off-policy metrics
Configuration Files
verl/trainer/config/algorithm.py- Rollout Correction parameters inRolloutCorrectionConfigverl/workers/config/actor.py- Rollout Correction parameters inPolicyLossConfigverl/trainer/config/actor/actor.yaml- Rollout Correction configuration sectionverl/trainer/config/ppo_trainer.yaml- Algorithm config with Rollout Correction
Documentation
docs/examples/config.rst- Configuration parameter descriptions
Example Scripts
recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh- DAPO example with Rollout Correctionexamples/rollout_correction/run_qwen2_5_7b_fsdp.sh- Basic exampleexamples/rollout_correction/run_qwen2_5_7b_fsdp_multi_rs.sh- Multi-RS example
Tests
tests/trainer/ppo/test_rollout_corr.py- Unit tests for IS/RS mechanismstests/trainer/ppo/test_rollout_corr_integration.py- Integration tests
Configuration Parameters
All parameters are under algorithm.rollout_correction:
rollout_is (str or null)
Importance sampling weights aggregation level:
null= No IS weights computed (metrics-only mode)"token": Per-token IS weightsDecoupled mode: ρ_t = π_old(t)/π_rollout(t)
Bypass/Pure IS mode: ρ_t = π_θ(t)/π_rollout(t)
Independent truncation per token
Typical threshold: 1.5 - 5.0
"sequence": Per-sequence weight ρ_seq = ∏_t ρ_tMultiplicative aggregation across sequence
Typical threshold: 2.0 - 10.0
All IS weights are safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]
rollout_is_threshold (str or float)
Threshold specification for IS weighting. Default: 2.0
Single float or float-like string: TIS via
.clamp(max=rollout_is_threshold)"lower_upper"string such as"0.5_5.0": IcePop, zero weights outside[lower, upper]Applied to IS weights for variance reduction
Separate from rejection sampling (controlled by
rollout_rsparameters)Unlike
rollout_rs, IcePop does not modifyresponse_mask; it only changes the IS coefficients
rollout_is_batch_normalize (bool)
Apply batch normalization to IS weights. Default: False
True: Normalize IS weights to have mean=1.0 within each batchToken-level IS: Normalizes over all token weights
Sequence-level IS: Normalizes over sequence means (one weight per sequence)
False: Use raw (truncated) IS weightsReduces variance by ensuring average weight is 1.0 per batch
Applied AFTER truncation to preserve truncation semantics
Only affects IS weight values, not rejection sampling
rollout_rs (str or null)
Rejection sampling aggregation modes. Supply a comma-separated string (spaces optional) using the canonical options implemented in rollout_corr_helper:
token_k1: Token-level rejection with-log rbounds (ratio thresholds supplied aslower_upper). Example:"0.6_1.4"token_k2: Token-level rejection with0.5 * (log r)^2(upper bound only)token_k3: Token-level rejection withexp(log r) - 1 - log r(upper bound only)seq_sum_k1: Sequence-level rejection with sum of-log r(ratio bounds)seq_sum_k2: Sequence-level rejection with sum of0.5 * (log r)^2(upper bound only)seq_sum_k3: Sequence-level rejection with sum ofexp(log r) - 1 - log r(upper bound only)seq_mean_k1: Sequence-level rejection with mean of-log r(ratio bounds)seq_mean_k2: Sequence-level rejection with mean of0.5 * (log r)^2(upper bound only)seq_mean_k3: Sequence-level rejection with mean ofexp(log r) - 1 - log r(upper bound only)seq_max_k2: Sequence-level rejection with max of0.5 * (log r)^2(upper bound only)seq_max_k3: Sequence-level rejection with max ofexp(log r) - 1 - log r(upper bound only)
rollout_rs_threshold (str, float, or null)
Threshold specification for rejection sampling.
Provide one entry per option, separated by commas. A single entry is broadcast to every option.
K1 KL modes (
*k1): Use"lower_upper"strings (e.g."0.7_1.3"). Supplying a float implies only the upper bound; the lower bound defaults to its reciprocal.K2/K3 KL modes (
*k2/*k3): Supply positive upper bounds (float or numeric string).Set to
nullto disable thresholds entirely (only valid whenrollout_rsis null).
Understanding the Framework: Components and Combinations
The rollout correction framework is built from orthogonal components that can be combined flexibly. Understanding these components helps you choose the right configuration for your scenario.
Key Components
Operating Mode (Section: Operation Modes)
Decoupled: Three policies (π_rollout, π_old, π_θ) with separate π_old computation
Bypass: Two policies (π_rollout = π_old, π_θ), skips π_old computation
Loss Function (in bypass mode, controlled by
loss_type)PPO-clip (
loss_type="ppo_clip", default): PPO clipped objective (IS handled by ratio)REINFORCE (
loss_type="reinforce"): Policy gradient with explicit IS weights (no clipping)
IS/RS Aggregation Level
Token: Per-token IS weights/rejection
Sequence: Sequence-level IS weights/rejection
See Mathematical Formulations for detailed theory.
Preset Configuration Guide
This section provides detailed guidance on choosing and using the verified presets. Each preset is a specific combination of components optimized for common scenarios.
Understanding the Presets
Available Preset Methods
Preset Method |
Estimator |
Mode |
IS Level |
RS Level |
Properties |
|---|---|---|---|---|---|
Decoupled PPO Mode (3 policies: π_rollout, π_old, π_θ) |
|||||
|
Token-TIS |
Decoupled |
token |
- |
Token-level IS weights |
|
Seq-TIS |
Decoupled |
sequence |
- |
Sequence-level IS weights |
|
Seq-MIS |
Decoupled |
sequence |
sequence |
Sequence IS + seq_sum_k1 RS |
|
Geo-RS |
Decoupled |
- |
sequence |
Geometric RS (seq_mean_k1) |
|
Geo-RS-Token-TIS |
Decoupled |
token |
sequence |
Geometric RS + token IS |
K3 KL Estimator (more stable for small KL values) |
|||||
|
K3-RS |
Decoupled |
- |
sequence |
seq_mean_k3 RS |
|
K3-RS-Token-TIS |
Decoupled |
token |
sequence |
seq_mean_k3 RS + token IS |
Bypass Mode (PPO-clip) (2 policies; ratio handles IS, RS masks outliers) |
|||||
|
- |
Bypass (PPO-clip) |
- |
- |
PPO-clip only |
|
Geo-RS |
Bypass (PPO-clip) |
- |
sequence |
PPO-clip + Geo-RS |
|
K3-RS |
Bypass (PPO-clip) |
- |
sequence |
PPO-clip + K3-RS |
Bypass Mode (REINFORCE) (2 policies; explicit IS weights, no PPO clipping) |
|||||
|
Seq-TIS |
Bypass (REINFORCE) |
sequence |
- |
REINFORCE with explicit IS |
|
Geo-RS |
Bypass (REINFORCE) |
- |
sequence |
REINFORCE with Geo-RS |
|
Geo-RS-Token-TIS |
Bypass (REINFORCE) |
token |
sequence |
REINFORCE + Geo-RS + token IS |
Other |
|||||
|
- |
- |
- |
- |
Metrics only, no correction |
Note:
Bypass mode sets π_old = π_rollout and uses
loss_typeto select the loss function:"ppo_clip"(default): PPO clipped objective where ratio = π_θ/π_rollout already handles IS"reinforce": REINFORCE with explicit IS weights as π_θ/π_rollout
Both loss types benefit from rejection sampling (RS) which masks out-of-distribution samples.
All estimators (Token-TIS, Seq-TIS, Seq-MIS, Geo-RS, …) are compatible with Decoupled and Bypass modes.
Other Supported Combinations (Manual Configuration Required)
Other supported combinations without preset methods:
Token IS + Token RS: Token-level IS weights + Token-level RS mask
Pure token RS: Token-level RS only, no IS weights
Pure sequence RS: Sequence-level RS only, no IS weights
See detailed configuration examples below for manual configurations.
Key properties:
Any aggregation level (token/sequence) works in either decoupled or bypass mode
All combinations are fully supported by the implementation
Rejection sampling is independent of IS weighting
Pure RS (
bypass_pg_rs) uses bypass + geometric RS withloss_type="reinforce"(no IS weights)
1. Decoupled Mode with Token-level Importance Sampling (decoupled_token_is)
Configuration:
config = RolloutCorrectionConfig.decoupled_token_is(threshold=2.0)
Components:
Operating Mode: Decoupled (3 policies)
Loss: PPO with clipping (only for the second drift correction)
IS Aggregation: Token-level
RS: None (can be added separately)
Equivalent YAML:
algorithm:
rollout_correction:
rollout_is: token
rollout_is_threshold: 2.0
rollout_rs: null
bypass_mode: false # Decoupled mode
Properties:
Independent truncation per token
Lower variance than sequence-level (product of ratios bounded individually)
Typical threshold: 1.5 - 5.0
Theory: See rollout_corr_math.md §3.3.1
2. Decoupled Mode with Sequence-level Importance Sampling (decoupled_seq_is)
Also known as: Seq-TIS (Sequence-Level Truncated IS)
Configuration:
config = RolloutCorrectionConfig.decoupled_seq_is(threshold=2.0)
Components:
Operating Mode: Decoupled (3 policies)
Loss: PPO with clipping (only for the second drift correction)
IS Aggregation: Sequence-level (Seq-TIS)
RS: None (can be added separately)
Equivalent YAML:
algorithm:
rollout_correction:
rollout_is: sequence
rollout_is_threshold: 2.0
rollout_rs: null
bypass_mode: false # Decoupled mode
Properties:
Multiplicative aggregation across sequence
More sensitive to outliers than token-level
Typical threshold: 2.0 - 10.0 (higher than token-level)
Theory: See rollout_corr_math.md §3.3.2
3. Decoupled Mode with Sequence-level IS + Rejection Sampling (decoupled_seq_is_rs)
Also known as: Seq-MIS (Sequence-Level Masked IS)
Configuration:
config = RolloutCorrectionConfig.decoupled_seq_is_rs(is_threshold=2.0, rs_threshold="0.5_2.0")
Components:
Operating Mode: Decoupled (3 policies)
Loss: PPO with clipping (only for the second drift correction)
IS Aggregation: Sequence-level (Seq-TIS)
RS: Sequence-level rejection (Seq-MIS)
Equivalent YAML:
algorithm:
rollout_correction:
rollout_is: sequence
rollout_is_threshold: 2.0
rollout_rs: seq_sum_k1
rollout_rs_threshold: 0.5_2.0
bypass_mode: false # Decoupled mode
Properties:
Double mechanism: IS reweighting (Seq-TIS) + rejection filtering (Seq-MIS)
Lower effective sample size (rejects outliers)
For severe off-policy gaps or when the distribution tail is “toxic” (garbage/adversarial samples)
When to use Seq-MIS over Seq-TIS:
Seq-TIS (clipping only): Maximizes information efficiency; extracts signal from all samples. Use when data is clean and mismatch is moderate.
Seq-MIS (rejection): Maximizes safety; acts as a hard trust region filter. Use when mismatch is severe or when high-weight samples are likely garbage rather than signal.
Theory: See rollout_corr_math.md §3.5
6. Bypass Mode with PPO-clip (bypass_ppo_clip)
Configuration:
config = RolloutCorrectionConfig.bypass_ppo_clip()
Components:
Operating Mode: Bypass (2 policies: π_rollout = π_old, π_θ)
Loss: PPO-clip (IS handled by ratio, no explicit IS weights)
IS Aggregation: None (PPO ratio handles it)
RS: None
Equivalent YAML:
rollout_correction:
rollout_is: null
rollout_rs: null
bypass_mode: true
loss_type: ppo_clip
Properties:
PPO clipped objective in bypass mode
The PPO ratio = π_θ/π_rollout already handles IS (no explicit IS weights needed)
Skips
actor.compute_log_prob()forward pass (2 policies instead of 3)No rejection sampling - use
bypass_ppo_clip_geo_rs()for RS
Configuration requirement:
Set
actor_rollout_ref.rollout.calculate_log_probs: true
Additional requirements for bypass mode:
Set
actor_rollout_ref.actor.use_rollout_log_probs: trueSet
actor_rollout_ref.actor.policy_loss.loss_mode: bypass_modeSet rollout correction config via
actor_rollout_ref.actor.policy_loss.rollout_correction
Theory: See rollout_corr_math.md §3.1.2
7. REINFORCE with IS (bypass_pg_is)
Configuration:
config = RolloutCorrectionConfig.bypass_pg_is(threshold=2.0)
Components:
Operating Mode: Bypass (2 policies: π_rollout, π_θ)
Loss: REINFORCE (policy gradient with explicit IS weights, no PPO clipping)
IS Aggregation: Sequence-level
RS: None
Equivalent YAML:
rollout_correction:
rollout_is: sequence
rollout_is_threshold: 2.0
rollout_rs: null
bypass_mode: true
loss_type: reinforce # REINFORCE with explicit IS weights
Properties:
REINFORCE loss with explicit IS weights (no PPO clipping)
Single forward pass (skips old_log_prob computation)
IS weights computed on-the-fly in loss function
Theory: See rollout_corr_math.md §3.2.2
Additional Useful Configurations (Not Exposed as Presets)
These configurations are fully supported but don’t have convenience preset methods yet.
1. Token IS + Token RS (token_is_rs)
Token-level IS weights with token-level RS mask.
Python:
config = RolloutCorrectionConfig(
rollout_is="token",
rollout_is_threshold=2.0,
rollout_rs="token_k1",
rollout_rs_threshold=2.0,
)
Properties: Per-token IS weights + per-token RS mask.
2. Pure Token RS (token_rs)
Token-level RS only, no IS weights.
Python:
config = RolloutCorrectionConfig(
rollout_is=None,
rollout_rs="token_k1",
rollout_rs_threshold=2.0,
)
Properties: Token-level RS mask, no IS reweighting.
3. Pure Sequence RS (seq_rs)
Sequence-level RS only, no IS weights.
Python:
config = RolloutCorrectionConfig(
rollout_is=None,
rollout_rs="seq_sum_k1",
rollout_rs_threshold="0.5_2.0",
)
Properties: Sequence-level RS mask, no IS reweighting.
Summary: How IS Weights are Processed
IS weights (rollout_is_weights) go through a fixed processing pipeline:
Stage 1: Safety Bound (Prevent Overflow)
Token level:
exp(clamp(log_ratio, -20, 20))per token → bounds each token to [2e-9, 5e8]Sequence level:
exp(clamp(sum(log_ratio), -20, 20))→ bounds product to [2e-9, 5e8], broadcast to all tokens
Stage 2: Truncation (Reduce Variance)
.clamp(max=rollout_is_threshold)→ caps weights at upper threshold (TIS: Truncated Importance Sampling)No lower truncation (preserves unbiasedness for small weights)
Stage 3: Padding Zeroing (Correct Aggregation)
weights * response_mask→ zeros out padding positions
Stage 4: Optional Batch Normalization
If
rollout_is_batch_normalize=True: Normalize weights to mean=1.0 within batchApplied after truncation to preserve truncation semantics
Rejection Sampling (Separate Mechanism)
Rejection sampling modifies response_mask (NOT weights) through compute_rollout_rejection_mask():
Computes safety-bounded ratios independently
Creates binary mask: tokens/sequences outside [lower_threshold, upper_threshold] → 0 (rejected)
Modified mask used for loss aggregation
Operation Modes
The framework provides two operating modes for computing π_old, which can be combined with different loss functions.
Operating Modes and Configuration
Configuration |
|
|
Operating Mode |
Loss Function |
Description |
|---|---|---|---|---|---|
Decoupled |
|
N/A |
Decoupled |
PPO |
Computes |
Bypass + PPO-clip |
|
|
Bypass |
PPO-clip |
PPO clipped objective (IS handled by ratio) |
Bypass + REINFORCE |
|
|
Bypass |
REINFORCE |
Policy gradient with explicit IS weights (no PPO clipping) |
Operating Mode Details
Decoupled Mode (Three Policies)
Policy setup:
π_rollout: Behavior policy (data collection)
π_old: Proximal policy (computed via
actor.compute_log_prob()at start of training epoch)π_θ: Current policy (being updated)
Configuration: bypass_mode = false
Properties:
✅ Achieves batch size invariance
✅ Separately corrects Drift 1 (rollout→old) and Drift 2 (old→current)
✅ Efficient stale data utilization
❌ Extra forward pass needed (
actor.compute_log_prob())
Theory: See rollout_corr_math.md §3.1.1
Bypass Mode (Two Policies)
Policy setup:
π_rollout: Behavior policy (data collection)
π_old = π_rollout: Proximal policy equals behavior policy
π_θ: Current policy (being updated)
Configuration: bypass_mode = true
Properties:
✅ Skips
actor.compute_log_prob()call (faster)✅ Handles off-policy correction via IS/RS (when using policy gradient with IS/RS)
✅ Uses two policies instead of three (π_rollout = π_old)
⚠️ Does not separate proximal policy from behavior policy (unlike decoupled mode)
Theory: See rollout_corr_math.md §3.1.2
IS/RS Aggregation Levels (Orthogonal to Operating Mode)
The aggregation level can be chosen independently of the operating mode. Any aggregation level works in either decoupled or bypass mode.
|
|
Behavior |
|---|---|---|
|
|
Disabled: No computation, no metrics, no rejection |
|
|
Rejection only: Compute metrics, NO weight correction, YES rejection sampling |
|
|
IS weights only: Weight correction enabled, NO rejection sampling |
|
|
Full correction: Both weight correction and rejection sampling enabled |
Key Insights
✅ Any IS/RS aggregation level (token/sequence/geometric) can be used in either decoupled or bypass mode
✅ You can use rejection sampling alone without IS weight correction (
rollout_is=null, rollout_rs="token_k1")✅ You can use IS weights alone without outlier rejection (
rollout_is="token", rollout_rs=null)✅ You can use both together (
rollout_is="token", rollout_rs="token_k1")✅ You can monitor metrics only without any correction by setting both to
nullbut still providing rollout_log_probs
Theory: See rollout_corr_math.md §3.3 for details on aggregation levels.
Example Workflow
Recommended: Bypass Mode
This workflow uses bypass mode for efficiency.
Start with metrics only to understand the off-policy gap:
rollout_correction: rollout_is: null rollout_rs: null bypass_mode: true # Bypass mode (recommended) loss_type: ppo_clip # Default: PPO clipped objective
Monitor
rollout_corr/kl,rollout_corr/log_ppl_abs_diff,rollout_corr/chi2_tokento assess off-policy gap.Enable rejection sampling if you see high outlier fractions:
rollout_correction: rollout_is: null rollout_rs: sequence # or "geometric" for higher sensitivity rollout_rs_threshold: 2.0 bypass_mode: true # Bypass mode loss_type: ppo_clip # or "reinforce" for explicit IS weights
This excludes outliers from training without modifying gradients.
Enable full IS correction (with REINFORCE loss) once comfortable with metrics:
rollout_correction: rollout_is: sequence # Recommended: unbiased, suitable for most cases rollout_is_threshold: 2.0 rollout_rs: sequence # or "geometric" for more aggressive filtering rollout_rs_threshold: 2.0 bypass_mode: true # Bypass mode loss_type: reinforce # REINFORCE with explicit IS weights
Benefits of bypass mode:
✅ Skips expensive
actor.compute_log_prob()forward pass (faster)✅
loss_typecontrols the loss function: “ppo_clip” (default) or “reinforce”✅ PPO-clip: IS handled by ratio (no explicit weights), RS mask applied
✅ REINFORCE: Explicit IS weights computed on-the-fly (π_θ / π_rollout)
✅ Both loss types work with all IS/RS combinations
Usage
Basic Setup
algorithm:
rollout_correction:
rollout_is: token # Enable IS weights at token level
rollout_is_threshold: 2.0 # Threshold for IS weights
rollout_rs: null # No rejection sampling
actor_rollout_ref:
rollout:
calculate_log_probs: true # Required!
Additional Configurations for Bypass Mode
Set
actor_rollout_ref.actor.use_rollout_log_probs: trueSet
actor_rollout_ref.actor.policy_loss.loss_mode: bypass_modeSet rollout correction config via
actor_rollout_ref.actor.policy_loss.rollout_correction
Metrics
All metrics are prefixed with rollout_corr/ in logs. For example, rollout_is_mean appears as rollout_corr/rollout_is_mean.
These metrics cover both:
Diagnostic metrics: KL divergence, perplexity differences (measuring off-policy gap)
Correction statistics: IS weights, rejection rates (measuring correction applied)
Core IS Weight Metrics
rollout_is_mean: Mean importance sampling weight across all valid tokensValue close to 1.0 indicates minimal off-policy gap
rollout_is_std: Standard deviation of IS weightsHigher values indicate greater variance in IS weights
rollout_is_min: Minimum IS weight observedShows the most underweighted token/sequence
For sequence/geometric: computed from unclamped log-space ratios (true minimum)
For token: computed from safety-bounded weights
rollout_is_max: Maximum IS weight observedShows the most overweighted token/sequence
For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound)
For token: computed from safety-bounded weights (before threshold clamping)
Compare with
rollout_is_thresholdto see truncation impact
Effective Sample Size
rollout_is_eff_sample_size: Effective sample size after IS weightingFormula:
1 / mean(weights²)where weights are normalizedRange: 0.0 to 1.0 (as fraction of original batch)
Lower values indicate weight concentration on fewer samples
Threshold Exceedance Metrics
rollout_is_ratio_fraction_high: Fraction of weights exceeding upper thresholdShows how often truncation/masking occurs on high end
For sequence/geometric: computed from unclamped log-space ratios (true exceedance)
For token: computed from safety-bounded weights (before threshold clamping)
rollout_is_ratio_fraction_low: Fraction of weights below lower threshold (1/upper_threshold)Diagnostic metric showing how many weights are below the reciprocal threshold
For sequence/geometric: computed from unclamped log-space ratios (true exceedance)
For token: computed from safety-bounded weights (before truncation)
Sequence-Level Metrics (for sequence aggregation)
rollout_is_seq_mean: Mean IS weight at sequence levelShould match
rollout_is_meanfor sequence-level aggregation
rollout_is_seq_std: Standard deviation of sequence-level IS weightsrollout_is_seq_min: Minimum sequence-level IS weightrollout_is_seq_max: Maximum sequence-level IS weightrollout_is_seq_max_deviation: Maximum absolute deviation from 1.0 at sequence levelShows worst-case sequence off-policy gap
rollout_is_seq_fraction_high: Fraction of sequences exceeding upper thresholdrollout_is_seq_fraction_low: Fraction of sequences below lower threshold
Rejection Sampling Metrics (when rollout_rs is enabled)
rollout_rs_masked_fraction: Fraction of tokens rejected via rejection samplingImportant: Rejection sampling modifies
response_mask(sets rejected tokens to 0)Separate from IS weights: IS weights are still truncated; rejection is an independent filtering step
Only present when
rollout_rsis enabled (token/sequence/geometric)
rollout_rs_seq_masked_fraction: Fraction of sequences with at least one rejected tokenShows sequence-level impact of rejection sampling
Token-level RS: sequence rejected if ANY token is outside [lower, upper]
Sequence-level RS: entire sequence rejected or accepted based on sequence-level ratio
Geometric RS: entire sequence rejected or accepted based on geometric mean
Off-Policy Diagnostic Metrics (Training vs Rollout Policy)
Note on terminology: These metrics use “training” to refer to the training reference policy and “rollout” to refer to π_rollout (the behavior policy used for data collection).
Decoupled mode: “training” = π_old (computed at start of training epoch)
Bypass/Pure IS mode: “training” = π_θ (current policy being trained)
In bypass/pure IS mode, metrics measure the drift between π_θ and π_rollout directly.
training_ppl: Perplexity of training reference policy (π_old in decoupled mode, π_θ in bypass/pure IS mode)Formula:
exp(-mean(log_probs))Lower values indicate higher model confidence
rollout_ppl: Perplexity of rollout policy π_rollout (e.g., vLLM BF16)ppl_ratio: Ratio of training PPL to rollout PPLFormula:
exp(mean(log(training_ppl / rollout_ppl)))Meaning: > 1.0 means training is less confident than rollout
training_log_ppl: Log perplexity of training policyUseful for identifying trends (linear scale)
rollout_log_ppl: Log perplexity of rollout policylog_ppl_diff: Mean difference in log perplexitiesFormula:
mean(log_ppl_rollout - log_ppl_training)Sign indicates which policy is more confident
log_ppl_abs_diff: Mean absolute log perplexity differenceMagnitude of off-policy gap regardless of direction
log_ppl_diff_max: Maximum log perplexity difference across sequencesIdentifies worst-case sequence
log_ppl_diff_min: Minimum log perplexity difference across sequenceskl: KL divergence KL(π_rollout || π_training)Formula:
mean(log_prob_rollout - log_prob_training)Note: Can be negative (rollout is less confident)
k3_kl: K3 divergence (equals KL(π_rollout || π_training) in expectation)Formula:
mean(exp(log_ratio) - log_ratio - 1)More stable than direct KL (non-negative per token)
Always >= 0
chi2_token: Chi-squared divergence at token levelFormula:
mean(ratio²) - 1where ratio = π_training/π_rolloutMeasures second moment of IS weight distribution
Always non-negative
chi2_seq: Chi-squared divergence at sequence levelFormula:
mean((∏_t ratio_t)²) - 1Sequence-level second moment of IS weights
More sensitive than token-level chi-squared
Example: Accessing Metrics in Code
# Metrics are returned from compute_rollout_correction_and_rejection_mask
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask
# Returns 3 values (weights, modified_response_mask, metrics)
weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
old_log_prob=training_log_probs, # from training policy
rollout_log_prob=rollout_log_probs, # from rollout policy
response_mask=response_mask,
rollout_is="token", # Enable IS weights at token level
rollout_is_threshold=2.0,
rollout_rs="token_k1",
rollout_rs_threshold="0.5_2.0",
)
# Extract IS weights (processed, zeroed at padding)
is_weights = weights_proto.batch["rollout_is_weights"]
# IS weights processing (with IS enabled at token level):
# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token
# 2. Truncated: .clamp(max=2.0) to cap extreme weights
# 3. Zeroed at padding positions
# Note: Truncation is ALWAYS applied to IS weights (TIS: Truncated Importance Sampling)
# modified_response_mask has rejection applied (since rollout_rs="token_k1"):
# 1. RS rejection: tokens outside [0.5, 2.0] masked to 0 via response_mask
# Note: RS and IS are separate mechanisms - both can be enabled independently
# All metrics have 'rollout_corr/' prefix
print(f"Mean IS weight: {metrics['rollout_corr/rollout_is_mean']:.3f}")
print(f"Effective sample size: {metrics['rollout_corr/rollout_is_eff_sample_size']:.3f}")
print(f"RS masked fraction: {metrics['rollout_corr/rollout_rs_masked_fraction']:.3f}")
print(f"KL divergence: {metrics['rollout_corr/kl']:.3f}")
# Check IS weights for valid tokens (non-padding)
valid_weights = is_weights[response_mask.bool()]
print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}")
print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}")
print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}")
print(f"✓ IS weights are capped at threshold: {(valid_weights <= 2.0).all()}")
# Check rejection via response_mask
rejected_tokens = (response_mask == 1) & (modified_response_mask == 0)
print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask")
print(f"✓ Rejection sampling modifies response_mask (separate from IS weight truncation)")
print(f"✓ IS weights are always truncated to [0, threshold] after safety bounding")
# Check for warning conditions
if metrics['rollout_corr/rollout_is_mean'] < 0.5 or metrics['rollout_corr/rollout_is_mean'] > 2.0:
print("⚠️ Warning: Mean IS weight far from 1.0, significant off-policy gap detected")
if metrics['rollout_corr/rollout_is_eff_sample_size'] < 0.3:
print("⚠️ Warning: Low effective sample size, high weight concentration")
Example: Monitoring Metrics During Training
# In your training loop
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(dataloader):
# ... rollout phase ...
# Compute IS weights and get metrics
rollout_corr_config = config.algorithm.get("rollout_correction", None)
if rollout_corr_config is not None:
weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
old_log_prob=batch.old_log_prob,
rollout_log_prob=batch.rollout_log_prob,
response_mask=batch.response_mask,
rollout_is=rollout_corr_config.get("rollout_is", None),
rollout_is_threshold=rollout_corr_config.get("rollout_is_threshold", 2.0),
rollout_rs=rollout_corr_config.get("rollout_rs", None),
rollout_rs_threshold=rollout_corr_config.get("rollout_rs_threshold", None),
)
# Log to tensorboard/wandb
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step=global_step)
# IMPORTANT: Update batch response_mask with rejection applied
batch.response_mask = modified_response_mask
# Use IS weights in training (always safety-bounded, zeroed at padding)
is_weights = weights_proto.batch["rollout_is_weights"]
# ... apply weights to policy gradient ...
Example: Conditional Alerting Based on Metrics
def check_rollout_correction_health(metrics, config):
"""Check if Rollout Correction metrics indicate healthy training."""
warnings = []
# Check mean IS weight
mean_weight = metrics['rollout_corr/rollout_is_mean']
if mean_weight < 0.5 or mean_weight > 2.0:
warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")
# Check effective sample size
ess = metrics['rollout_corr/rollout_is_eff_sample_size']
if ess < 0.3:
warnings.append(f"Effective sample size {ess:.3f} is too low")
# Check standard deviation
std = metrics['rollout_corr/rollout_is_std']
if std > 1.0:
warnings.append(f"IS weight std {std:.3f} is too high")
# Check KL divergence
kl = metrics['rollout_corr/kl']
if abs(kl) > 0.1:
warnings.append(f"KL divergence {kl:.3f} indicates significant off-policy gap")
# Check chi-squared divergence
if 'rollout_corr/chi2_token' in metrics:
chi2_token = metrics['rollout_corr/chi2_token']
if chi2_token > 1.0:
warnings.append(f"Chi-squared divergence (token) {chi2_token:.3f} indicates severe distribution shift")
if warnings:
print("⚠️ Rollout Correction Health Warnings:")
for warning in warnings:
print(f" - {warning}")
return False
else:
print("✅ Rollout Correction metrics look healthy")
return True
# Use in training
_, _, metrics = compute_rollout_correction_and_rejection_mask(...)
is_healthy = check_rollout_correction_health(metrics, config)
if not is_healthy:
# Consider adjusting config or investigating issues
print("Consider:")
print(" - Tightening rollout_is_threshold")
print(" - Switching to geometric aggregation level")
print(" - Checking if rollout and training policies are too different")
Running Examples
Start with the basic token-level truncate configuration:
bash examples/rollout_correction/run_qwen2_5_7b_fsdp.sh
Monitor metrics for 1-2 epochs before adjusting parameters.
Configuration Examples
Example 1: IS Weights Only (Token Level)
algorithm:
rollout_correction:
rollout_is: token
rollout_is_threshold: 2.0
rollout_rs: null # No rejection sampling
Example 2: Rejection Sampling Only (No IS Weights)
algorithm:
rollout_correction:
rollout_is: null # No IS weights
rollout_rs: token_k1
rollout_rs_threshold: "0.5_2.0"
Example 3: Both IS and RS (Token RS)
algorithm:
rollout_correction:
rollout_is: token
rollout_is_threshold: 2.0
rollout_rs: token_k1
rollout_rs_threshold: "0.5_2.0"
Example 5: Bypass Mode with PPO-clip (Default)
algorithm:
rollout_correction:
rollout_is: token
rollout_is_threshold: 2.0
rollout_rs: token_k1
rollout_rs_threshold: "0.5_2.0"
bypass_mode: true # Skip old_log_prob computation
loss_type: ppo_clip # PPO clipped objective (default)
Skips expensive actor.compute_log_prob() forward pass. PPO ratio = π_θ/π_rollout handles IS.
Example 6: Bypass Mode with REINFORCE
rollout_correction:
rollout_is: sequence # Explicit IS correction in loss
rollout_is_threshold: 2.0
rollout_rs: null # Optional: can add rejection sampling
bypass_mode: true
loss_type: reinforce # REINFORCE with explicit IS weights
No PPO clipping, pure policy gradient with IS correction
Example 7: Bypass Mode with PPO-clip + Rejection Sampling
rollout_correction:
rollout_is: sequence # Computed for metrics
rollout_is_threshold: 2.0
rollout_rs: seq_max_k2 # Sequence max χ²/2 guard
rollout_rs_threshold: 2.5
bypass_mode: true
loss_type: ppo_clip # PPO clipped objective (IS handled by ratio)
PPO clipping with rejection sampling. IS handled by PPO ratio (no explicit IS weights).
Troubleshooting
Issue: High spread in IS weights
Symptoms: rollout_is_std > 1.0, rollout_is_eff_sample_size < 0.3
Solutions:
Switch from
sequencetogeometriclevelTighten thresholds
Verify rollout and training aren’t too different
Issue: Mean IS weight far from 1.0
Symptoms: rollout_is_mean < 0.5 or > 2.0
Solutions:
Verify
calculate_log_probs=Trueis setCheck rollout_log_probs are correctly passed
Check for systematic distribution shift
Debugging: Visualizing Metrics
Example: Plot IS weight distribution
import matplotlib.pyplot as plt
import numpy as np
def plot_is_metrics(metrics_history):
"""Plot rollout IS metrics over training steps."""
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Plot 1: Mean IS weight over time
axes[0, 0].plot(metrics_history['rollout_corr/rollout_is_mean'])
axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[0, 0].set_title('Mean IS Weight')
axes[0, 0].set_xlabel('Step')
axes[0, 0].legend()
# Plot 2: Effective sample size
axes[0, 1].plot(metrics_history['rollout_corr/rollout_is_eff_sample_size'])
axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
axes[0, 1].set_title('Effective Sample Size')
axes[0, 1].set_xlabel('Step')
axes[0, 1].legend()
# Plot 3: KL divergence over time
axes[1, 0].plot(metrics_history['rollout_corr/kl'], label='KL')
axes[1, 0].plot(metrics_history['rollout_corr/k3_kl'], label='K3 KL')
axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3)
axes[1, 0].set_title('KL Divergence')
axes[1, 0].set_xlabel('Step')
axes[1, 0].legend()
# Plot 4: PPL ratio over time
axes[1, 1].plot(metrics_history['rollout_corr/ppl_ratio'])
axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 1].set_title('PPL Ratio (Training/Rollout)')
axes[1, 1].set_xlabel('Step')
axes[1, 1].legend()
# Plot 5: Chi-squared divergence
if 'rollout_corr/chi2_token' in metrics_history:
axes[1, 2].plot(metrics_history['rollout_corr/chi2_token'], label='Token-level')
if 'rollout_corr/chi2_seq' in metrics_history:
axes[1, 2].plot(metrics_history['rollout_corr/chi2_seq'], label='Seq-level')
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Warning')
axes[1, 2].set_title('Chi-squared Divergence')
axes[1, 2].set_xlabel('Step')
axes[1, 2].legend()
else:
axes[1, 2].axis('off')
plt.tight_layout()
plt.savefig('rollout_is_metrics.png', dpi=150)
print("Saved plot to rollout_is_metrics.png")
Example: Metric collection during training
# Collect metrics over time
metrics_history = {
'rollout_corr/rollout_is_mean': [],
'rollout_corr/rollout_is_eff_sample_size': [],
'rollout_corr/kl': [],
'rollout_corr/k3_kl': [],
'rollout_corr/ppl_ratio': [],
'rollout_corr/chi2_token': [],
'rollout_corr/chi2_seq': [],
}
# In training loop
for step in range(num_steps):
# ... compute IS weights and rejection mask ...
_, _, metrics = compute_rollout_correction_and_rejection_mask(...)
# Store metrics
for key in metrics_history.keys():
if key in metrics:
metrics_history[key].append(metrics[key])
# Plot every 100 steps
if step % 100 == 0:
plot_is_metrics(metrics_history)
Performance Impact
Memory overhead: ~1% of model memory
Computational overhead: 1-3% depending on level
Training stability: Significantly improved when off-policy gap exists
Testing
Run the test suite to verify everything works:
# Basic unit tests
python tests/trainer/ppo/test_rollout_corr.py
# Integration tests (if pytest is available)
pytest tests/trainer/ppo/test_rollout_corr_integration.py -v
Expected output: All tests pass ✓
Additional Resources
Implementation:
verl/trainer/ppo/rollout_corr_helper.pyExamples:
examples/rollout_correction/DAPO Example:
recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh
Summary
Rollout Correction provides a unified framework for handling general off-policy problems in RL:
✅ Corrects ANY distribution shift between data collection and training
✅ Supports diverse scenarios: policy mismatch, staleness, replay buffers, off-policy algorithms
✅ Numerical stability with safety bounds and rejection mechanisms
✅ Comprehensive diagnostics: KL, perplexity, χ² divergence
✅ Flexible methods from token-level to sequence-level aggregation
✅ Memory-efficient implementation
References
Mathematical Formulations - Detailed mathematical theory and derivations for all rollout correction methods
Your Efficient RL Framework Secretly Brings You Off-Policy RL Training