# Source code for verl.trainer.ppo.core_algos

# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO-like algorithms.
"""

# Names exported by a star-import of this module; every other helper defined
# here has to be imported explicitly by name.
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]

from collections import defaultdict
from enum import Enum
from typing import Optional

import numpy as np
import torch

import verl.utils.torch_functional as verl_F
from verl.trainer.config import AlgoConfig

# Maps a policy-loss name to the function registered under that name.
POLICY_LOSS_REGISTRY = {}


def register_policy_loss(name):
    """Build a decorator that stores a policy loss function under ``name``.

    Registering a second function under an existing name silently replaces
    the earlier entry.

    Args:
        name (str): Key under which the decorated function is stored in
            ``POLICY_LOSS_REGISTRY``.

    Returns:
        function: A decorator that records the function it receives and hands
        the same function back unchanged.
    """

    def _record(loss_fn):
        POLICY_LOSS_REGISTRY[name] = loss_fn
        return loss_fn

    return _record


def get_policy_loss_fn(name):
    """Look up a registered policy loss function.

    Args:
        name: `(str)`
            Name the policy loss was registered under.

    Returns:
        `(callable)`: The function stored in ``POLICY_LOSS_REGISTRY`` for ``name``.

    Raises:
        ValueError: If nothing has been registered under ``name``.
    """
    if name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[name]

    supported_modes = list(POLICY_LOSS_REGISTRY.keys())
    raise ValueError(f"Unsupported loss mode: {name}. Supported modes are: {supported_modes}")


# Maps an advantage-estimator name (plain string) to its implementation.
ADV_ESTIMATOR_REGISTRY = {}


def register_adv_est(name_or_enum):
    """Build a decorator that registers an advantage estimator function.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            The name or enum of the advantage estimator. Enum members are
            stored under their ``.value``.

    Returns:
        A decorator that stores the decorated function in
        ``ADV_ESTIMATOR_REGISTRY`` and returns it unchanged.

    Raises:
        ValueError: Raised by the decorator when a *different* function is
            already registered under the same name. Registering the same
            function again is accepted.
    """

    def _bind(estimator):
        key = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
        if key in ADV_ESTIMATOR_REGISTRY:
            previous = ADV_ESTIMATOR_REGISTRY[key]
            if previous != estimator:
                raise ValueError(f"Adv estimator {key} has already been registered: {previous} vs {estimator}")
        ADV_ESTIMATOR_REGISTRY[key] = estimator
        return estimator

    return _bind


def get_adv_estimator_fn(name_or_enum):
    """Get the advantage estimator function with a given name.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            The name or enum of the advantage estimator. Enum members are
            looked up by their ``.value``.

    Returns:
        `(callable)`: The advantage estimator function.

    Raises:
        ValueError: If no estimator has been registered under the name. The
            message lists the names that are currently registered.
    """
    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
    if name not in ADV_ESTIMATOR_REGISTRY:
        raise ValueError(
            f"Unknown advantage estimator: {name}. "
            f"Registered estimators are: {list(ADV_ESTIMATOR_REGISTRY.keys())}"
        )
    return ADV_ESTIMATOR_REGISTRY[name]


class AdvantageEstimator(str, Enum):
    """Names of the built-in advantage estimators.

    Using an enumeration class avoids spelling errors in adv_estimator. The
    class also derives from ``str``, so each member compares equal to its
    string value.

    Note(haibin.lin): this enum class is immutable after creation. Extending this
    enum for new estimators may not be necessary since users can always just call
    `verl.trainer.ppo.core_algos.register_adv_est` with string name for a custom
    advantage estimator instead.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"


class AdaptiveKLController:
    """
    KL coefficient that adapts to the observed KL divergence, as described in
    https://arxiv.org/pdf/1909.08593.pdf

    The current coefficient is exposed as ``value``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        """Rescale ``value`` according to how far ``current_kl`` is from the target.

        Args:
            current_kl (float): Current KL divergence value.
            n_steps (int): Number of steps taken.
        """
        # Relative deviation from the target, limited to +/-20% per update.
        relative_error = current_kl / self.target - 1
        clipped_error = np.clip(relative_error, -0.2, 0.2)
        self.value *= 1 + clipped_error * n_steps / self.horizon


class FixedKLController:
    """KL controller whose coefficient ``value`` never changes."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept an update request and do nothing.

        Exists so that fixed and adaptive controllers share one interface.

        Args:
            current_kl (float): Current KL divergence value (unused).
            n_steps (int): Number of steps taken (unused).
        """
        return None


def get_kl_controller(kl_ctrl):
    """Factory function to create appropriate KL controller based on configuration.

    Args:
        kl_ctrl: Configuration object containing KL controller settings. It must
            expose ``type`` and ``kl_coef``; the adaptive type additionally reads
            ``horizon`` and ``target_kl``.

    Returns:
        KL controller instance (FixedKLController or AdaptiveKLController).

    Raises:
        NotImplementedError: If controller type is not supported. The message
            names the offending type.
        AssertionError: If adaptive controller horizon is not positive.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    if kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    raise NotImplementedError(
        f"Unsupported KL controller type: {kl_ctrl.type!r}. Supported types are: 'fixed', 'adaptive'"
    )


@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est("gae")
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and returns with a backward pass over the response tokens.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask. The tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length). Passed through ``verl_F.masked_whiten``
            with ``response_mask`` before being returned.
        Returns: `(torch.Tensor)`
            shape: (bs, response_length). Sum of the un-whitened advantages and ``values``.

    """
    with torch.no_grad():
        # State carried from step t+1 to step t; both start at 0 past the last token.
        nextvalues = 0
        lastgaelam = 0
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]

        # Walk the sequence from the last token to the first.
        for t in reversed(range(gen_len)):
            # TD error at step t, then the GAE recursion on top of the step t+1 estimate.
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam_ = delta + gamma * lam * lastgaelam

            # skip values and TD-error on observation tokens
            # (where the mask is 0 the carried state from step t+1 is kept unchanged)
            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues
            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam

            advantages_reversed.append(lastgaelam)
        # The list was filled back-to-front; reverse it to restore time order.
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        # Returns are taken from the raw advantages, before whitening.
        returns = advantages + values
        # NOTE(review): masked_whiten presumably normalizes over the masked tokens — confirm in verl_F.
        advantages = verl_F.masked_whiten(advantages, response_mask)
    return advantages, returns


# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est("grpo")
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Each response's score is the sum of its token-level rewards. Scores are
    grouped by ``index``, centered on the group mean and optionally divided by
    the group standard deviation.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length)
        index: `(np.ndarray)`
            index array for grouping; entries must be hashable
        epsilon: `(float)`
            small value to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            whether to scale the GRPO advantage
        config: `(Optional[AlgoConfig])`
            algorithm configuration object (not read by this function)

    Note:
        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
        A group that contains a single response uses mean 0 and std 1, so its
        score is passed through without centering.

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        Returns: `(torch.Tensor)`
            shape is (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx, group in id2score.items():
            if len(group) == 1:
                id2mean[idx] = scores.new_zeros(())
                id2std[idx] = scores.new_ones(())
            else:
                # Stack the 0-dim score tensors instead of rebuilding them with
                # torch.tensor(...): this keeps device and dtype and avoids the
                # accidental extra nesting level the std computation used to have.
                group_scores = torch.stack(group)
                id2mean[idx] = group_scores.mean()
                id2std[idx] = group_scores.std()
        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        # Broadcast the per-response advantage over its (masked) tokens.
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


@register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pass@k advantage in a GRPO-style outcome-reward formulation.

    Within every group only the highest-scoring response receives a non-zero
    advantage, equal to r_max - r_second_max (optionally divided by the group
    standard deviation). All other responses get zero.

    Implemented as described in https://arxiv.org/abs/2503.19595.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        index: (bs,) → group ID per sample
        epsilon: float for numerical stability
        norm_adv_by_std_in_grpo: overwritten by the value read from ``config``
        config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo";
            required despite the ``None`` default

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length), the same tensor as ``advantages``

    Raises:
        ValueError: If a group holds fewer than two responses.
    """
    assert config is not None
    # The config value takes precedence over the keyword argument.
    norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)
    scores = token_level_rewards.sum(dim=-1)  # (bs,)
    advantages = torch.zeros_like(scores)

    # group id -> list of (row in the batch, score of that row)
    members_by_group = defaultdict(list)

    with torch.no_grad():
        num_rows = scores.shape[0]
        for row in range(num_rows):
            members_by_group[index[row]].append((row, scores[row]))

        for idx, members in members_by_group.items():
            rewards = torch.stack([score for _, score in members])  # (k,)
            if rewards.numel() < 2:
                raise ValueError(
                    f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}."
                )
            top_values, top_positions = torch.topk(rewards, 2)
            best_row = members[top_positions[0].item()][0]
            gap = top_values[0] - top_values[1]
            if norm_adv_by_std_in_grpo:
                gap = gap / (torch.std(rewards) + epsilon)
            advantages[best_row] = gap

    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages


@register_adv_est(
    AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE
)  # or simply: @register_adv_est("reinforce_plus_plus_baseline")
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
    (with only one scalar reward for each response).

    Each response's score (sum of its token-level rewards) is centered on the
    mean score of its group, broadcast over the response tokens and then passed
    through ``verl_F.masked_whiten`` over the whole batch.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,). Group id per sample; entries must be hashable, as in
            the other group-based estimators of this module.
        epsilon: `(float)`
            accepted for signature compatibility; not used here
        config: (AlgoConfig) algorithm config (not read by this function)

    Note:
        A group with a single response uses a baseline of 0.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx, group in id2score.items():
            if len(group) == 1:
                id2mean[idx] = scores.new_zeros(())
            else:
                # torch.stack keeps the scores on their device and dtype instead
                # of rebuilding a tensor element by element.
                id2mean[idx] = torch.stack(group).mean()
        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]

        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        # NOTE(review): masked_whiten presumably normalizes over the masked tokens — confirm in verl_F.
        scores = verl_F.masked_whiten(scores, response_mask) * response_mask

    return scores, scores


@register_adv_est(AdvantageEstimator.RLOO)  # or simply: @register_adv_est("rloo")
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    RLOO (leave-one-out) advantage, based on https://arxiv.org/abs/2402.14740

    For a group of n > 1 responses the advantage of a response is
    ``n / (n - 1) * (score - group_mean)``. A response that is alone in its
    group keeps its raw score.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id per sample
        epsilon: `(float)`
            not used by this estimator
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)

    group_scores = defaultdict(list)
    group_mean = {}

    with torch.no_grad():
        num_rows = scores.shape[0]
        for row in range(num_rows):
            group_scores[index[row]].append(scores[row])

        for group, members in group_scores.items():
            count = len(members)
            if count == 1:
                group_mean[group] = torch.tensor(0.0)
            elif count > 1:
                group_mean[group] = torch.mean(torch.tensor(members))
            else:
                raise ValueError(f"no score in prompt index: {group}")

        for row in range(num_rows):
            group = index[row]
            group_size = len(group_scores[group])
            if group_size <= 1:
                # Nothing to leave out: the score stays as it is.
                continue
            scaled_score = scores[row] * group_size / (group_size - 1)
            scaled_mean = group_mean[group] * group_size / (group_size - 1)
            scores[row] = scaled_score - scaled_mean
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


@register_adv_est(AdvantageEstimator.OPO)  # or simply: @register_adv_est("opo")
def compute_opo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    OPO advantage, based on https://arxiv.org/pdf/2505.23585

    The baseline of a group is the length-weighted mean of its scores, where
    the length of a response is the sum of its ``response_mask`` row. A group
    with a single response uses a baseline of 0.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            group id per sample
        epsilon: `(float)`
            not used by this estimator
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    token_counts = response_mask.sum(dim=-1)
    scores = token_level_rewards.sum(dim=-1)

    group_scores = defaultdict(list)
    group_lengths = defaultdict(list)
    group_baseline = {}

    with torch.no_grad():
        num_rows = scores.shape[0]
        for row in range(num_rows):
            group = index[row]
            group_scores[group].append(scores[row])
            group_lengths[group].append(token_counts[row])

        for group, members in group_scores.items():
            count = len(members)
            if count == 1:
                baseline = torch.tensor(0.0)
            elif count > 1:
                score_vec = torch.tensor(members)
                length_vec = torch.tensor(group_lengths[group])
                # Length-weighted average of the group's scores.
                baseline = (length_vec * score_vec).sum() / length_vec.sum()
            else:
                raise ValueError(f"no score in prompt index: {group}")
            group_baseline[group] = baseline

        for row in range(num_rows):
            scores[row] = scores[row] - group_baseline[index[row]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)  # or simply: @register_adv_est("reinforce_plus_plus")
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    REINFORCE++ advantage: discounted returns passed through ``verl_F.masked_whiten``.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config; required despite the ``None``
            default, its ``gamma`` attribute is the discount factor

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the discounted returns before whitening
    """
    assert config is not None
    discount = config.gamma
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        carry = 0

        last_step = token_level_rewards.shape[1] - 1
        for step in range(last_step, -1, -1):
            carry = token_level_rewards[:, step] + discount * carry
            returns[:, step] = carry
            # Reset after EOS: masked positions pass nothing on to earlier steps.
            carry = carry * response_mask[:, step]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask

    return advantages, returns


@register_adv_est(AdvantageEstimator.REMAX)  # or simply: @register_adv_est("remax")
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    ReMax advantage, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    The return at a token is the sum of the masked rewards from that token to
    the end of the response; the advantage subtracts the per-response baseline
    on unmasked tokens.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        config: (AlgoConfig) algorithm config (not read by this function)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """

    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Suffix sum: reverse, accumulate, reverse back.
        reversed_cumsum = masked_rewards.flip(dims=[-1]).cumsum(dim=-1)
        returns = reversed_cumsum.flip(dims=[-1])
        masked_baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - masked_baseline

    return advantages, returns


@register_adv_est(AdvantageEstimator.GPG)  # or simply: @register_adv_est("gpg")
def compute_gpg_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    f_norm: float = 1.0,
    alpha: float = 1.0,
    config=None,
    **kwargs,
):
    """
    Compute advantage for GPG, operating only on Outcome reward
    (with only one scalar reward for each response).

    Each response's score (sum of its token-level rewards) is centered on the
    mean score of its group, multiplied by ``alpha`` and divided by ``f_norm``.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            shape: (bs,)
        epsilon: (float)
            not used by this estimator
        f_norm: (float)
            divisor applied to the centered score
        alpha: (float)
            ignored: the value is always recomputed as
            ``bs / max(number of non-zero scores, 1)``
        config: (dict) algorithm config (not read by this function)

    Note:
        A group with a single response uses a mean of 0.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        Returns: `(torch.Tensor)`
            shape: (bs, response_length); the same tensor as ``advantages``
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        # Rescale by the share of responses that carry a non-zero score.
        m = torch.count_nonzero(scores)
        alpha = bsz / m.clamp(min=1)

        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        for idx, group in id2score.items():
            if len(group) == 1:
                id2mean[idx] = scores.new_zeros(())
            else:
                # Only the group mean is needed; the per-group std that used to
                # be computed here was never read.
                id2mean[idx] = torch.stack(group).mean()
        for i in range(bsz):
            scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Subtract a scaled log-probability gap from the token-level scores.

    Args:
        token_level_scores (torch.Tensor): Token-level reward scores.
        old_log_prob (torch.Tensor): Log probabilities from current policy.
        ref_log_prob (torch.Tensor): Log probabilities from reference policy.
        kl_ratio (float): KL penalty coefficient.

    Returns:
        torch.Tensor: ``token_level_scores - (old_log_prob - ref_log_prob) * kl_ratio``.
    """
    log_prob_gap = old_log_prob - ref_log_prob
    penalty = log_prob_gap * kl_ratio
    return token_level_scores - penalty


def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str) -> torch.Tensor:
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`:
            shape: (bs, response_length)
        loss_agg_mode: (str) choices:
            method to aggregate the loss matrix into a scalar. One of
            "token-mean", "seq-mean-token-sum", "seq-mean-token-mean",
            "seq-mean-token-sum-norm".

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss

    Raises:
        ValueError: If ``loss_agg_mode`` is not one of the supported choices.
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        token_counts = torch.sum(loss_mask, dim=-1)
        # A fully masked sequence has a zero numerator; dividing by 1 instead of
        # 0 makes it contribute 0 rather than turning the whole loss into NaN.
        safe_counts = torch.where(token_counts > 0, token_counts, torch.ones_like(token_counts))
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / safe_counts  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor
        # (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    return loss
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped policy objective and related metrics for PPO.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional):
            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
            Defaults to None; must be provided unless both `cliprange_low` and
            `cliprange_high` are given.
        cliprange_low (float, optional):
            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.
        cliprange_high (float, optional):
            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.
        clip_ratio_c (float, optional):
            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
            Defaults to 3.0.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        pg_loss: aggregated policy-gradient loss.
        pg_clipfrac: masked mean of the indicator that the clipped loss exceeds the unclipped one.
        ppo_kl: masked mean of ``old_log_prob - log_prob`` (after clamping the difference).
        pg_clipfrac_lower: masked mean of the indicator that the dual-clip bound
            is active on a negative-advantage token.

    Raises:
        AssertionError: If ``clip_ratio_c`` is not greater than 1.0.
        ValueError: If no clip range can be resolved from the arguments.
    """
    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    if cliprange_low is None or cliprange_high is None:
        # Fail early with a clear message instead of a TypeError from `1 - None`.
        raise ValueError("cliprange must be provided unless both cliprange_low and cliprange_high are given.")

    negative_approx_kl = log_prob - old_log_prob
    # Clamp negative_approx_kl for stability
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(
        ratio, 1 - cliprange_low, 1 + cliprange_high
    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(
        pg_losses1, pg_losses2
    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual-clip: for negative advantages the loss is additionally capped at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("gpg")
def compute_policy_loss_gpg(
    old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None
):
    """Compute the GPG policy-gradient loss.

    Adapted from
    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495

    Args:
        old_log_prob: accepted for interface compatibility; not used here
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) aggregation mode for `agg_loss`
        config: accepted for interface compatibility; not used here

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via GPG
        Three zero tensors, filling the slots the other policy losses use for
        their metrics.
    """
    pg_losses = -log_prob * advantages

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)


@register_policy_loss("clip_cov")
def compute_policy_loss_clip_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the clipped policy objective and related metrics for Clip-Cov.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Required despite the ``None`` default. The following attributes are read:
            ``clip_ratio`` (clipping parameter ε for standard PPO, see
            https://arxiv.org/abs/1707.06347), ``clip_ratio_low`` and
            ``clip_ratio_high`` (fall back to ``clip_ratio`` when None),
            ``policy_loss.clip_cov_ratio`` (share of tokens whose loss is
            dropped, 0.0002 when None), ``policy_loss.clip_cov_lb`` (lower
            covariance bound, 1.0 when None) and ``policy_loss.clip_cov_ub``
            (upper covariance bound, 5.0 when None).

    Returns:
        pg_loss: aggregated loss.
        pg_clipfrac: masked mean of the indicator that a token was dropped by the covariance rule.
        ppo_kl: masked mean of ``old_log_prob - log_prob``.
        A zero tensor, filling the slot other policy losses use for the dual-clip fraction.
    """
    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange
    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0
    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0

    assert clip_cov_ratio > 0, "clip_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_losses1 = -advantages * ratio

    # 1 keeps a token's loss, 0 drops it.
    corr = torch.ones_like(advantages)
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)

    # Per-token product of centered advantage and centered log-probability.
    cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * (
        log_prob - verl_F.masked_mean(log_prob.detach(), response_mask)
    )
    # Masked tokens and tokens already clipped by PPO are never candidates.
    cov_all[response_mask == 0] = -torch.inf
    cov_all[clip_by_origin] = -torch.inf

    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)
    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)
    top_k_idx = torch.nonzero(top_k_idx)

    if len(top_k_idx) > 0:
        # Randomly keep at most clip_num of the candidates.
        perm = torch.randperm(len(top_k_idx))
        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]
    else:
        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)

    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0

    pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)

    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)


@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the policy objective and related metrics for KL-Cov.

    The tokens with the largest covariance between advantage and
    log-probability receive an additional ``ppo_kl_coef * |log_prob - old_log_prob|``
    penalty; all other tokens use the plain ``-advantages * ratio`` loss.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config:
            Required despite the ``None`` default. The following attributes are read:
            ``policy_loss.kl_cov_ratio`` (share of valid tokens selected by
            covariance, 0.0002 when None) and ``policy_loss.ppo_kl_coef``
            (coefficient of the KL penalty term, 1.0 when None).

    Returns:
        pg_loss: aggregated loss.
        A zero tensor (no clip fraction is reported).
        ppo_kl_abs: masked mean of ``|log_prob - old_log_prob|``.
        A zero tensor (no dual-clip fraction is reported).
    """
    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0

    assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask)
    pg_losses1 = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl
    # Same tensor object as pg_losses1: the selected entries are overwritten in place below.
    pg_losses = pg_losses1

    all_valid = response_mask > 0
    # Flat (row-major) positions of the unmasked tokens.
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    # With a positive kl_cov_ratio this is 0 exactly when there is no unmasked token.
    k = min(kl_cov_ratio, len(all_valid_adv))

    if k != 0:
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        if len(large_cov_idxs) != 0:
            # Map positions among the valid tokens back to (row, column) in the batch.
            large_cov_idxs = all_valid_idx[large_cov_idxs]
            rows = large_cov_idxs // advantages.shape[1]
            cols = large_cov_idxs % advantages.shape[1]
            pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)


def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
    """Compute categorical entropy loss (For backward compatibility)

    Args:
        logits (torch.Tensor): shape is (bs, response_length, vocab_size)
        response_mask (torch.Tensor): shape is (bs, response_length)
        loss_agg_mode (str, optional): Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        entropy: a scalar torch.Tensor

    """
    # compute entropy
    token_entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
    entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return entropy_loss


def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped value-function loss for PPO.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Clip range for value prediction updates.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.FloatTensor):
            A scalar tensor containing the aggregated value-function loss
            (0.5 times the aggregated squared error).
        vf_clipfrac:
            Masked mean of the indicator that the clipped loss was larger than
            the unclipped one.
    """
    # Keep the new prediction within cliprange_value of the old value estimate.
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob: Log-probabilities under the current policy.
        ref_logprob: Log-probabilities under the reference policy; combined element-wise
            with ``logprob``.
        kl_penalty (str): Which estimator to use:

            * ``"kl"`` / ``"k1"``: ``logprob - ref_logprob``
            * ``"abs"``: ``|logprob - ref_logprob|``
            * ``"mse"`` / ``"k2"``: ``0.5 * (logprob - ref_logprob) ** 2``
            * ``"low_var_kl"`` / ``"k3"``: ``exp(d) - d - 1`` with ``d = ref_logprob - logprob``
              clamped to [-20, 20]; the result is clamped to [-10, 10]
            * ``"full"``: reserved, not implemented

    Returns:
        The element-wise KL estimate.

    Raises:
        NotImplementedError: For ``"full"`` and for any unrecognised ``kl_penalty`` value.
    """
    # NOTE: inside this function the name `kl_penalty` refers to the mode argument,
    # not to the function itself.
    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # For numerical stability
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
        raise NotImplementedError("kl_penalty='full' is not implemented")

    raise NotImplementedError(
        f"Unsupported kl_penalty: {kl_penalty!r}. "
        "Supported values are: 'kl', 'k1', 'abs', 'mse', 'k2', 'low_var_kl', 'k3'."
    )
def compute_pf_ppo_reweight_data(
    data,
    reweight_method: str = "pow",
    weight_pow: float = 2.0,
):
    """Reweight the data based on the token_level_scores.

    Each sample's score is the sum of its ``token_level_scores``. Scores are turned into
    sampling weights, and a new batch of the same size is drawn with replacement according
    to those weights.

    Args:
        data: DataProto object, containing batch, non_tensor_batch and meta_info
        reweight_method: str, choices: "pow", "max_min", "max_random"
        weight_pow: float, the power of the weight

    Returns:
        A deep copy of ``data`` whose ``batch``, ``non_tensor_batch`` and per-sample
        ``meta_info`` lists (lists whose length equals the batch size) hold the resampled
        rows. Other ``meta_info`` entries are carried over unchanged. ``data`` itself is
        not modified.

    Raises:
        ValueError: If ``reweight_method`` is not supported.
    """
    from copy import deepcopy

    @torch.no_grad()
    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:
        """Compute importance weights for resampling based on scores.

        Args:
            scores (torch.Tensor): Tensor of scores to compute weights from.
            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').
            weight_pow (float): Power exponent for 'pow' method.

        Returns:
            torch.Tensor: Computed importance weights.

        Raises:
            ValueError: If reweight_method is not supported.
        """
        if reweight_method == "pow":
            return torch.pow(torch.abs(scores), weight_pow)
        if reweight_method == "max_min":
            max_score = torch.max(scores)
            min_score = torch.min(scores)
            return torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)
        if reweight_method == "max_random":
            max_score = torch.max(scores)
            return torch.where(scores == max_score, 0.4, 0.1)
        raise ValueError(f"Unsupported reweight_method: {reweight_method}")

    scores = data.batch["token_level_scores"].sum(dim=-1)
    weights = compute_weights(scores, reweight_method, weight_pow)
    # Keep every weight strictly positive so the sampling distribution is always valid.
    weights = torch.clamp(weights + 1e-8, min=1e-8)

    batch_size = scores.shape[0]
    sample_indices = torch.multinomial(weights, batch_size, replacement=True)

    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}

    # Tensor.numpy() only works on CPU tensors; move the indices first so this also works
    # when the batch tensors live on an accelerator.
    sample_indices_np = sample_indices.cpu().numpy()

    resampled_non_tensor_batch = {}
    for key, array in data.non_tensor_batch.items():
        if isinstance(array, np.ndarray):
            resampled_non_tensor_batch[key] = array[sample_indices_np]
        else:
            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]

    # Only lists with one entry per sample are resampled; everything else is shared as-is.
    resampled_meta_info = {
        key: [value[i] for i in sample_indices_np]
        if isinstance(value, list) and len(value) == batch_size
        else value
        for key, value in data.meta_info.items()
    }

    # Copy the container so any other attributes of `data` are preserved, then swap in
    # the resampled contents.
    resampled_data = deepcopy(data)
    resampled_data.batch = type(data.batch)(resampled_batch)
    resampled_data.batch.batch_size = data.batch.batch_size
    resampled_data.non_tensor_batch = resampled_non_tensor_batch
    resampled_data.meta_info = resampled_meta_info

    return resampled_data