# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO-like algorithms.
"""
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
from collections import defaultdict
from enum import Enum
from typing import Optional
import numpy as np
import torch
import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig
POLICY_LOSS_REGISTRY = {}


def register_policy_loss(name):
    """Build a decorator that records a policy loss function in ``POLICY_LOSS_REGISTRY``.

    Args:
        name (str): Key under which the decorated function is stored.

    Returns:
        function: A decorator that stores the function it receives under ``name``
        (overwriting any earlier entry) and hands the same function back unchanged.
    """

    def _store(loss_fn):
        POLICY_LOSS_REGISTRY[name] = loss_fn
        return loss_fn

    return _store
def get_policy_loss_fn(name):
    """Look up a registered policy loss function.

    Args:
        name: `(str)`
            Name the policy loss was registered under.

    Returns:
        `(callable)`: The function stored in ``POLICY_LOSS_REGISTRY`` for ``name``.

    Raises:
        ValueError: If nothing is registered under ``name``.
    """
    loss_name = name
    if loss_name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[loss_name]
    # Unknown name: report it together with every key currently registered.
    raise ValueError(
        f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}"
    )
ADV_ESTIMATOR_REGISTRY = {}


def register_adv_est(name_or_enum):
    """Build a decorator that records an advantage estimator in ``ADV_ESTIMATOR_REGISTRY``.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            Registry key; for an ``Enum`` member its ``.value`` is used as the key.

    Registering the same function again under the same key is accepted. Registering a
    different function under a key that is already taken raises ``ValueError``.
    """

    def _register(fn):
        if isinstance(name_or_enum, Enum):
            name = name_or_enum.value
        else:
            name = name_or_enum
        if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:
            raise ValueError(
                f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}"
            )
        ADV_ESTIMATOR_REGISTRY[name] = fn
        return fn

    return _register
def get_adv_estimator_fn(name_or_enum):
    """Get the advantage estimator function with a given name.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            The name or enum of the advantage estimator. For an ``Enum`` member
            its ``.value`` is used as the lookup key.

    Returns:
        `(callable)`: The advantage estimator function.

    Raises:
        ValueError: If no estimator is registered under the resolved name.
    """
    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
    if name not in ADV_ESTIMATOR_REGISTRY:
        # List the registered keys so a typo is easy to spot.
        raise ValueError(
            f"Unknown advantage estimator: {name}. "
            f"Supported estimators are: {list(ADV_ESTIMATOR_REGISTRY.keys())}"
        )
    return ADV_ESTIMATOR_REGISTRY[name]
class AdvantageEstimator(str, Enum):
    """Using an enumeration class to avoid spelling errors in adv_estimator.

    Each member's string value is the key under which the matching estimator
    function is stored in ``ADV_ESTIMATOR_REGISTRY``.

    Note(haibin.lin): this enum class is immutable after creation. Extending this
    enum for new estimators may not be necessary since users can always just call
    `verl.trainer.ppo.core_algos.register_adv_est` with string name for a custom advantage
    estimator instead.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"
class AdaptiveKLController:
    """
    KL coefficient that is rescaled after every update, as described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf

    The current coefficient is exposed as ``value``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value, self.target, self.horizon = init_kl_coef, target_kl, horizon

    def update(self, current_kl, n_steps):
        """Rescale the KL coefficient from the measured KL divergence.

        Args:
            current_kl (float): Current KL divergence value.
            n_steps (int): Number of steps taken.
        """
        # Relative deviation from the target, limited to +/-20% per update.
        relative_error = current_kl / self.target - 1
        clipped_error = np.clip(relative_error, -0.2, 0.2)
        self.value *= 1 + clipped_error * n_steps / self.horizon
class FixedKLController:
    """KL controller whose coefficient ``value`` never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Do nothing; present so both controllers share the same interface.

        Args:
            current_kl (float): Current KL divergence value (unused).
            n_steps (int): Number of steps taken (unused).
        """
        return None
def get_kl_controller(kl_ctrl):
    """Factory function to create appropriate KL controller based on configuration.

    Args:
        kl_ctrl: Configuration object containing KL controller settings. ``type`` and
            ``kl_coef`` are read for every controller; ``target_kl`` and ``horizon``
            are additionally read for the adaptive controller.

    Returns:
        KL controller instance (FixedKLController or AdaptiveKLController).

    Raises:
        NotImplementedError: If controller type is not supported.
        AssertionError: If adaptive controller horizon is not positive.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    # Name the offending value so a misconfiguration is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported KL controller type: {kl_ctrl.type!r}. Supported types are: 'fixed', 'adaptive'."
    )
@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Generalized Advantage Estimation, adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length); passed through ``verl_F.masked_whiten``
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); un-whitened advantages plus ``values``
    """
    with torch.no_grad():
        next_value = 0
        running_gae = 0
        per_step_advantages = []
        # Walk the response backwards so each step can reuse the estimate of the step after it.
        for step in range(token_level_rewards.shape[-1] - 1, -1, -1):
            keep = response_mask[:, step]
            td_error = token_level_rewards[:, step] + gamma * next_value - values[:, step]
            candidate_gae = td_error + gamma * lam * running_gae
            # Masked positions carry the previous value / GAE through untouched.
            next_value = values[:, step] * keep + (1 - keep) * next_value
            running_gae = candidate_gae * keep + (1 - keep) * running_gae
            per_step_advantages.append(running_gae)
        per_step_advantages.reverse()
        advantages = torch.stack(per_step_advantages, dim=1)
        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est("grpo")
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GRPO advantage from outcome rewards (one scalar reward per response).

    Each response's score is the sum of its token-level rewards. Scores are centred
    on the mean of the group sharing the same ``index`` entry and, optionally,
    divided by that group's standard deviation. A group with a single response
    uses mean 0 and std 1.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length)
        index: `(np.ndarray)`
            index array for grouping
        epsilon: `(float)`
            small value to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            whether to scale the GRPO advantage
        config: `(Optional[AlgoConfig])`
            algorithm configuration object (not read by this function)

    Note:
        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        Returns: `(torch.Tensor)`
            shape is (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for row in range(bsz):
            id2score[index[row]].append(scores[row])

        # Group statistics must all be computed before any score is overwritten below.
        for idx in id2score:
            group_size = len(id2score[idx])
            if group_size == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif group_size > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(bsz):
            group = index[row]
            if norm_adv_by_std_in_grpo:
                scores[row] = (scores[row] - id2mean[group]) / (id2std[group] + epsilon)
            else:
                scores[row] = scores[row] - id2mean[group]

        # Spread the per-response advantage over its valid tokens.
        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for Pass@k using a GRPO-style outcome reward formulation.
    Only the best response per group gets a non-zero advantage: r_max - r_second_max.
    Implemented as described in https://arxiv.org/abs/2503.19595.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        index: (bs,) → group ID per sample
        epsilon: float for numerical stability
        norm_adv_by_std_in_grpo: if True, divide the advantage by the group's std.
            Only used when ``config`` is None; otherwise the config value wins.
        config: (AlgoConfig) algorithm settings. When given, its
            "norm_adv_by_std_in_grpo" entry (default True) overrides the argument.

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length); the same tensor as ``advantages``

    Raises:
        ValueError: If any group contains fewer than 2 samples.
    """
    # ``config`` is declared optional, so fall back to the explicit argument instead of
    # failing when it is absent. With a config, the config value takes precedence.
    if config is not None:
        norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)

    scores = token_level_rewards.sum(dim=-1)  # (bs,)
    advantages = torch.zeros_like(scores)
    id2scores = defaultdict(list)
    id2indices = defaultdict(list)

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            idx = index[i]
            id2scores[idx].append(scores[i])
            id2indices[idx].append(i)

        for idx in id2scores:
            rewards = torch.stack(id2scores[idx])  # (k,)
            if rewards.numel() < 2:
                raise ValueError(
                    f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}."
                )
            topk, topk_idx = torch.topk(rewards, 2)
            r_max, r_second_max = topk[0], topk[1]
            # Map the within-group position of the best response back to its batch row.
            i_max = id2indices[idx][topk_idx[0].item()]
            advantage = r_max - r_second_max
            if norm_adv_by_std_in_grpo:
                std = torch.std(rewards)
                advantage = advantage / (std + epsilon)
            advantages[i_max] = advantage

    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages
@register_adv_est(
    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE
)  # or simply: @register_adv_est("reinforce_plus_plus_baseline")
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
    (with only one scalar reward for each response).

    Each response's score (sum of token-level rewards) is centred on the mean score of
    the group sharing the same ``index`` entry (a single-response group uses mean 0),
    broadcast over the response tokens, and then passed through ``verl_F.masked_whiten``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,). Group id per sample. Elements are used as dictionary keys, so
            they must hash by value (tensor elements hash by identity and would each
            form their own group).
        epsilon: `(float)`
            not read by this function
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        # All group means are computed before any score is overwritten below.
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]

        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask) * response_mask

    return scores, scores
@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est("rloo")
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    RLOO advantage, based on https://arxiv.org/abs/2402.14740

    For a group of n > 1 responses sharing the same ``index`` entry, each score is
    replaced by ``score * n / (n - 1) - group_mean * n / (n - 1)``. Scores of
    single-response groups are left unchanged.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for row in range(bsz):
            id2score[index[row]].append(scores[row])

        # Means are fixed here, before the in-place rewrite of ``scores`` below.
        for idx in id2score:
            group_size = len(id2score[idx])
            if group_size == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif group_size > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(bsz):
            response_num = len(id2score[index[row]])
            if response_num <= 1:
                continue
            scores[row] = scores[row] * response_num / (response_num - 1) - id2mean[index[row]] * response_num / (
                response_num - 1
            )

        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est("opo")
def compute_opo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    OPO advantage, based on https://arxiv.org/pdf/2505.23585

    The baseline of a group (samples sharing the same ``index`` entry) is the
    length-weighted mean of its scores, where a response's length is the sum of its
    ``response_mask`` row. A single-response group uses baseline 0.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    response_length = response_mask.sum(dim=-1)
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2len = defaultdict(list)
    id2bsl = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for row in range(bsz):
            group = index[row]
            id2score[group].append(scores[row])
            id2len[group].append(response_length[row])

        # Baselines are fixed here, before ``scores`` is rewritten in place below.
        for idx in id2score:
            group_size = len(id2score[idx])
            if group_size == 1:
                id2bsl[idx] = torch.tensor(0.0)
            elif group_size > 1:
                score_tensor = torch.tensor(id2score[idx])
                len_tensor = torch.tensor(id2len[idx])
                id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum()
            else:
                raise ValueError(f"no score in prompt index: {idx}")

        for row in range(bsz):
            scores[row] = scores[row] - id2bsl[index[row]]

        scores = scores.unsqueeze(-1) * response_mask
    return scores, scores
@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est("reinforce_plus_plus")
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    REINFORCE++ advantage.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Returns are the discounted reward-to-go with discount ``config.gamma``; advantages
    are those returns passed through ``verl_F.masked_whiten`` and masked.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config; must not be None (``gamma`` is read from it)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    assert config is not None
    gamma = config.gamma

    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        discounted_sum = 0
        for step in range(token_level_rewards.shape[1] - 1, -1, -1):
            discounted_sum = token_level_rewards[:, step] + gamma * discounted_sum
            returns[:, step] = discounted_sum
            # Reset after EOS: masked positions pass nothing on to earlier steps.
            discounted_sum = discounted_sum * response_mask[:, step]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask
    return advantages, returns
@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est("remax")
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    ReMax advantage, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Reverse cumulative sum along the token axis: flip, accumulate, flip back.
        reversed_cumsum = masked_rewards.flip(dims=[-1]).cumsum(dim=-1)
        returns = reversed_cumsum.flip(dims=[-1])
        baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - baseline
    return advantages, returns
@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est("gpg")
def compute_gpg_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    f_norm: float = 1.0,
    alpha: float = 1.0,
    config=None,
    **kwargs,
):
    """
    Compute advantage for GPG, operating only on Outcome reward
    (with only one scalar reward for each response).

    Each response's score (sum of token-level rewards) is centred on the mean score of
    the group sharing the same ``index`` entry (a single-response group uses mean 0),
    scaled by ``alpha / f_norm`` and broadcast over the valid response tokens.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,)
        epsilon: (float) accepted for signature compatibility; not read.
        f_norm: (float) divisor applied to the centred score.
        alpha: (float) accepted for signature compatibility; the value passed in is
            replaced by ``bs / max(number of non-zero scores, 1)``.
        config: (dict) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        # Scale factor derived from the batch: batch size over the count of non-zero scores.
        m = torch.count_nonzero(scores)
        alpha = bsz / m.clamp(min=1)

        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        # All group means are computed before any score is overwritten below.
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Subtract a scaled log-probability gap from the token-level scores.

    Args:
        token_level_scores (torch.Tensor): Token-level reward scores.
        old_log_prob (torch.Tensor): Log probabilities of the policy being penalised.
        ref_log_prob (torch.Tensor): Log probabilities from reference policy.
        kl_ratio (float): KL penalty coefficient.

    Returns:
        torch.Tensor: ``token_level_scores - (old_log_prob - ref_log_prob) * kl_ratio``.
    """
    penalty = (old_log_prob - ref_log_prob) * kl_ratio
    return token_level_scores - penalty
# (stray "[docs]" link text left over from HTML documentation extraction; not code)
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_agg_mode: (str) choices:
            method to aggregate the loss matrix into a scalar.
            - "token-mean": ``verl_F.masked_mean`` over the whole matrix.
            - "seq-mean-token-sum": masked sum per sequence, then mean over sequences.
            - "seq-mean-token-mean": masked mean per sequence, then mean over sequences.
              A sequence whose mask is entirely zero contributes 0 instead of NaN.
            - "seq-mean-token-sum-norm": masked sum over everything divided by
              ``loss_mask.shape[-1]``.

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss

    Raises:
        ValueError: If ``loss_agg_mode`` is not one of the modes above.
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        token_counts = torch.sum(loss_mask, dim=-1)
        # A fully masked sequence has a zero numerator; dividing by 1 instead of 0 keeps
        # its contribution at 0 rather than turning the whole batch loss into NaN.
        safe_counts = torch.where(token_counts > 0, token_counts, torch.ones_like(token_counts))
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / safe_counts  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor
        # (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    return loss
# (stray "[docs]" link text left over from HTML documentation extraction; not code)
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode: str = "token-mean",
):
    """
    Clipped PPO policy objective (with dual-clip for negative advantages) and metrics.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional):
            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
            Defaults to None (must be provided unless both bounds below are given).
        cliprange_low (float, optional):
            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
        cliprange_high (float, optional):
            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
        clip_ratio_c (float, optional):
            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
            Defaults to 3.0. Must be greater than 1.0.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)`` — the aggregated
        loss, the masked fraction of tokens where the clipped term exceeds the
        unclipped one, the masked mean of ``old_log_prob - log_prob`` (after clamping),
        and the masked fraction of negative-advantage tokens hit by the dual clip.
    """
    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )

    # Log importance ratio, clamped for stability before exponentiation.
    log_ratio = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio)
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    cliprange_low = cliprange if cliprange_low is None else cliprange_low
    cliprange_high = cliprange if cliprange_high is None else cliprange_high

    unclipped_losses = -advantages * ratio
    # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clipped_losses = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    standard_clip_losses = torch.maximum(unclipped_losses, clipped_losses)
    pg_clipfrac = verl_F.masked_mean(torch.gt(clipped_losses, unclipped_losses).float(), response_mask)

    # Dual clip: for negative advantages the loss is additionally capped at -A * clip_ratio_c.
    dual_cap_losses = -advantages * clip_ratio_c
    dual_clip_losses = torch.min(dual_cap_losses, standard_clip_losses)
    negative_adv = advantages < 0
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(standard_clip_losses, dual_cap_losses) * negative_adv.float(), response_mask
    )

    pg_losses = torch.where(negative_adv, dual_clip_losses, standard_clip_losses)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gpg")
def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None):
    """GPG policy gradient loss, adapted from
    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495

    Args:
        old_log_prob: not read by this loss
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: aggregation mode forwarded to `agg_loss`
        config: not read by this loss

    return:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via GPG
        followed by three zero tensors filling the remaining slots of the
        4-tuple that the other policy losses return
    """
    per_token_losses = -log_prob * advantages
    pg_loss = agg_loss(loss_mat=per_token_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    placeholders = (torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0))
    return (pg_loss, *placeholders)
@register_policy_loss("clip_cov")
def compute_policy_loss_clip_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Required in practice (its attributes are read unconditionally). Fields used:
            - ``clip_ratio``: clipping parameter ε for standard PPO.
            - ``clip_ratio_low`` / ``clip_ratio_high``: lower / upper clip range; each
              falls back to ``clip_ratio`` when None.
            - ``policy_loss.clip_cov_ratio``: fraction of valid tokens to drop from the
              loss. Falls back to 0.0002 when None; must be > 0.
            - ``policy_loss.clip_cov_lb`` / ``policy_loss.clip_cov_ub``: exclusive
              covariance bounds for candidate tokens. Fall back to 1.0 / 5.0 when None.

    Returns:
        tuple: ``(pg_loss, pg_clipfrac, ppo_kl, zero)`` — the aggregated loss, the masked
        fraction of tokens whose loss was zeroed by the covariance rule, the masked mean
        of ``old_log_prob - log_prob``, and a constant zero tensor.
    """
    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0

    assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio

    # Per-token multiplier on the loss: 1 keeps the token, 0 removes it.
    corr = torch.ones_like(advantages)
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    # Token-wise product of centred advantage and centred log-prob.
    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (
        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)
    )
    # Masked tokens and tokens already affected by the PPO clip can never be candidates.
    cov_all[response_mask == 0] = -torch.inf
    cov_all[clip_by_origin] = -torch.inf

    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
    top_k_idx = torch.nonzero(top_k_idx)

    if len(top_k_idx) > 0:
        # Random subset of at most ``clip_num`` candidates (not the largest ones).
        perm = torch.randperm(len(top_k_idx))
        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
    else:
        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)

    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0

    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)

    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Policy objective and related metrics for KL-Cov: the valid tokens with the largest
    advantage/log-prob covariance receive an extra absolute-log-ratio penalty.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Read unconditionally. Fields used:
            - ``policy_loss.kl_cov_ratio``: fraction of valid tokens selected by top
              covariance. Falls back to 0.0002 when None; must be > 0.
            - ``policy_loss.ppo_kl_coef``: coefficient of the penalty term. Falls back
              to 1.0 when None.

    Returns:
        tuple: ``(pg_loss, zero, ppo_kl_abs, zero)`` — the aggregated loss, a constant
        zero tensor, the masked mean of ``|log_prob - old_log_prob|``, and another
        constant zero tensor.
    """
    policy_cfg = config.policy_loss
    kl_cov_ratio = policy_cfg.kl_cov_ratio if policy_cfg.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = policy_cfg.ppo_kl_coef if policy_cfg.ppo_kl_coef is not None else 1.0

    assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)

    # Plain importance-weighted loss; selected entries are overwritten in place below.
    pg_losses = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl

    valid = response_mask > 0
    # Flat (row-major) positions of the valid tokens.
    valid_flat_idx = torch.nonzero(valid.reshape(-1), as_tuple=True)[0]
    valid_adv = advantages[valid].detach().reshape(-1).cpu()
    valid_logp = log_prob[valid].detach().reshape(-1).cpu()

    k = min(kl_cov_ratio, len(valid_adv))

    if k != 0:
        cov_lst_all = (valid_adv - valid_adv.mean()) * (valid_logp - valid_logp.mean())
        k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        if len(large_cov_idxs) != 0:
            # Translate positions among valid tokens back to (row, column) in the full matrix.
            large_cov_idxs = valid_flat_idx[large_cov_idxs]
            n_cols = advantages.shape[1]
            rows = large_cov_idxs // n_cols
            cols = large_cov_idxs % n_cols
            pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
    """Aggregate per-token categorical entropy into a scalar (For backward compatibility)

    Args:
        logits (torch.Tensor): shape is (bs, response_length, vocab_size)
        response_mask (torch.Tensor): shape is (bs, response_length)
        loss_agg_mode (str): aggregation mode forwarded to `agg_loss`

    Returns:
        entropy: a scalar torch.Tensor
    """
    # Per-token entropy, (bs, response_len), reduced with the requested aggregation mode.
    return agg_loss(
        loss_mat=verl_F.entropy_from_logits(logits),
        loss_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
    )
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    Clipped value-function loss for PPO.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Clip range for value prediction updates.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.FloatTensor):
            A scalar tensor containing the aggregated value-function loss.
        vf_clipfrac:
            Masked mean (via ``verl_F.masked_mean``) of the indicator that the clipped
            squared error exceeds the unclipped one.
    """
    # Keep the new prediction within +/- cliprange_value of the old value.
    lower_bound = values - cliprange_value
    upper_bound = values + cliprange_value
    vpredclipped = verl_F.clip_by_value(vpreds, lower_bound, upper_bound)

    unclipped_sq_err = (vpreds - returns) ** 2
    clipped_sq_err = (vpredclipped - returns) ** 2
    # Element-wise worst case of the two squared errors.
    worst_sq_err = torch.max(unclipped_sq_err, clipped_sq_err)

    vf_loss = 0.5 * agg_loss(loss_mat=worst_sq_err, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    vf_clipfrac = verl_F.masked_mean(torch.gt(clipped_sq_err, unclipped_sq_err).float(), response_mask)
    return vf_loss, vf_clipfrac
# (stray "[docs]" link text left over from HTML documentation extraction; not code)
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty: str) -> torch.FloatTensor:
    """Compute a per-element KL divergence estimate given logprob and ref_logprob.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob: Log-probabilities under the current policy.
        ref_logprob: Log-probabilities under the reference policy; combined
            element-wise with ``logprob``.
        kl_penalty: Name of the estimator to use:

            - ``"kl"`` / ``"k1"``: ``logprob - ref_logprob``
            - ``"abs"``: ``|logprob - ref_logprob|``
            - ``"mse"`` / ``"k2"``: ``0.5 * (logprob - ref_logprob) ** 2``
            - ``"low_var_kl"`` / ``"k3"``: ``exp(d) - d - 1`` with
              ``d = ref_logprob - logprob`` clamped to ``[-20, 20]``; the result
              is clamped to ``[-10, 10]``
            - ``"full"``: reserved, not implemented

    Returns:
        Tensor with the element-wise penalty for the selected estimator.

    Raises:
        NotImplementedError: If ``kl_penalty`` is ``"full"`` or an unknown mode.
    """
    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # For numerical stability: bound the exponent before calling exp().
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # Here logprob and ref_logprob would have to contain the logits for every
        # token in the vocabulary.
        raise NotImplementedError("kl_penalty mode 'full' is not implemented")

    # Name the offending mode so a typo in the config is easy to spot.
    raise NotImplementedError(
        f"Unsupported kl_penalty mode: {kl_penalty!r}. "
        "Supported modes are: 'kl', 'k1', 'abs', 'mse', 'k2', 'low_var_kl', 'k3'"
    )
def compute_pf_ppo_reweight_data(
    data,
    reweight_method: str = "pow",
    weight_pow: float = 2.0,
):
    """Resample the batch with replacement, weighted by each sample's summed token_level_scores.

    Args:
        data: DataProto object, containing batch, non_tensor_batch and meta_info
        reweight_method: str, choices: "pow", "max_min", "max_random"
        weight_pow: float, the power of the weight

    Returns:
        A deep copy of ``data`` whose ``batch``, ``non_tensor_batch`` and
        batch-sized list entries of ``meta_info`` are replaced by the resampled
        rows. The number of rows is unchanged; ``data`` itself is not modified.

    Raises:
        ValueError: If ``reweight_method`` is not supported.
    """
    from copy import deepcopy

    @torch.no_grad()
    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:
        """Compute importance weights for resampling based on scores.

        Args:
            scores (torch.Tensor): Tensor of scores to compute weights from.
            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').
            weight_pow (float): Power exponent for 'pow' method.

        Returns:
            torch.Tensor: Computed importance weights.

        Raises:
            ValueError: If reweight_method is not supported.
        """
        if reweight_method == "pow":
            weights = torch.pow(torch.abs(scores), weight_pow)
        elif reweight_method == "max_min":
            # Only the extreme-scoring samples get weight.
            max_score = torch.max(scores)
            min_score = torch.min(scores)
            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)
        elif reweight_method == "max_random":
            # Best-scoring samples get 0.4, every other sample 0.1.
            max_score = torch.max(scores)
            weights = torch.where(scores == max_score, 0.4, 0.1)
        else:
            raise ValueError(f"Unsupported reweight_method: {reweight_method}")
        return weights

    # One scalar score per sample: sum over the last dimension.
    scores = data.batch["token_level_scores"].sum(dim=-1)
    weights = compute_weights(scores, reweight_method, weight_pow)
    # Keep every weight strictly positive so the sampling distribution is never all-zero.
    weights = torch.clamp(weights + 1e-8, min=1e-8)

    batch_size = scores.shape[0]
    sample_indices = torch.multinomial(weights, batch_size, replacement=True)

    # Tensor entries are indexed on whatever device they already live on.
    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}

    # Tensor.numpy() only works for CPU tensors; the indices inherit the device of
    # the scores, so move them to host memory before converting.
    sample_indices_np = sample_indices.cpu().numpy()
    sample_indices_list = sample_indices_np.tolist()

    resampled_non_tensor_batch = {}
    for key, array in data.non_tensor_batch.items():
        if isinstance(array, np.ndarray):
            resampled_non_tensor_batch[key] = array[sample_indices_np]
        else:
            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_list]

    # Only lists with one entry per sample are resampled; everything else is passed through.
    resampled_meta_info = {}
    for key, value in data.meta_info.items():
        if isinstance(value, list) and len(value) == batch_size:
            resampled_meta_info[key] = [value[i] for i in sample_indices_list]
        else:
            resampled_meta_info[key] = value

    resampled_data = deepcopy(data)
    resampled_data.batch = type(data.batch)(resampled_batch)
    resampled_data.batch.batch_size = data.batch.batch_size
    resampled_data.non_tensor_batch = resampled_non_tensor_batch
    resampled_data.meta_info = resampled_meta_info

    return resampled_data