PPO Ray Trainer

We implement the RayPPOTrainer, a trainer that runs in the driver process on a single CPU/GPU node (CPU by default).

The RayPPOTrainer includes three core functions: data preparation, WorkerGroup initialization, and the PPO training loop.

Data Preparation

The RayPPOTrainer, as a single process, is responsible for loading a complete batch of samples (prompts) from the dataset and then dispatching them to the worker groups running on different GPUs.

To generalize the data loading, we implement the RLHFDataset class to load the preprocessed parquet files, apply chat templates to the prompts, add padding, truncate prompts that exceed the max prompt length, and then tokenize.

self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
                                 tokenizer=self.tokenizer,
                                 config=self.config.data)

The dataloader then iterates over the dataset with the PPO mini-batch size, as sketched below.
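For reference, the dataloader can be built roughly as follows. This is a minimal sketch, not verl's exact code: we assume the collate_fn from verl.utils.dataset.rl_dataset, take the batch size from the data config, and omit the sampler and other options the real trainer wires in.

from torch.utils.data import DataLoader
from verl.utils.dataset.rl_dataset import collate_fn

# each batch_dict yielded here becomes one batch in the PPO training loop below
self.train_dataloader = DataLoader(dataset=self.train_dataset,
                                   batch_size=self.config.data.train_batch_size,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=collate_fn)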

WorkerGroup Initialization

We first introduce a basic implementation of initializing the WorkerGroup of the actor model on a given set of GPUs.

# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
# For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
# For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                use_gpu=True,
                                max_colocate_count=1)
# define the actor-rollout class to be initialized on the remote workers
actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)
# define actor_rollout worker group
actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                                    ray_cls_with_init=actor_rollout_cls,
                                                    default_megatron_kwargs=config.actor_rollout.megatron)

In the implementation above, each WorkerGroup (actor_rollout_worker_group, critic_worker_group, and ref_worker_group) runs in its own set of processes.

The driver process can then call the distributed compute functions within actor_rollout_worker_group and the other roles to construct the RL training loop.
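For example, a call from the driver looks like an ordinary method call; dispatching the inputs to the workers and collecting their sharded outputs is handled by the transfer protocol attached to each function. A minimal sketch, reusing actor_rollout_worker_group from above:

# prompts and output are DataProto objects (see PPO Training Loop below);
# the single call fans out to every worker in the group and gathers the results
output = actor_rollout_worker_group.generate_sequences(prompts)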

For models colocated on the same set of GPUs, we further provide a fine-grained optimization that merges the WorkerGroups of different roles into the same process. This saves the redundant CUDA/distributed contexts that separate processes would otherwise hold.

# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
# See TODO(url) for more information.
all_wg = {}
for resource_pool, class_dict in self.resource_pool_to_cls.items():
    worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
    wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
    spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
    all_wg.update(spawn_wg)

if self.use_critic:
    self.critic_wg = all_wg['critic']
    self.critic_wg.init_model()

if self.use_reference_policy:
    self.ref_policy_wg = all_wg['ref']
    self.ref_policy_wg.init_model()

if self.use_rm:
    self.rm_wg = all_wg['rm']
    self.rm_wg.init_model()

# create the rollout last so that vLLM gets a better estimate of the KV cache memory
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()

Note

For the Megatron backend, if we merge the WorkerGroups into the same processes, all roles will use the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role within the same distributed context. If you want to use a different 3D parallel size for each role, please follow an architecture similar to the first code block and initialize each role's WorkerGroup separately.

PPO Training Loop

We implement the PPO training loop by calling the functions in each role's worker_group. The input and output of each function is a DataProto object, implemented in protocol.py. In the training loop, the trainer dispatches data to, and collects results from, the different GPUs following the transfer protocols wrapped around the workers' functions. The computation over PPO micro-batches happens inside the update_actor and update_critic functions.
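The snippet below is a minimal, self-contained sketch of the DataProto operations used in fit() (from_single_dict, pop and union); the keys and tensor shapes are illustrative only.

import torch
from verl import DataProto

# build a DataProto from a dict of tensors sharing the same batch dimension
batch = DataProto.from_single_dict({
    'input_ids': torch.randint(0, 1000, (4, 8)),
    'attention_mask': torch.ones(4, 8, dtype=torch.long),
})
# pop moves the listed keys into a new DataProto (e.g. the generation inputs)
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask'])
# union merges another DataProto's tensors back into this one
batch = batch.union(gen_batch)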

To extend to other RLHF algorithms, such as DPO and GRPO, please refer to Extend to other RL(HF) algorithms.

def fit(self):
    """
    The training loop of PPO.
    The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
    from verl.utils.tracking import Tracking
    from omegaconf import OmegaConf
    # additional imports used later in this excerpt
    import os
    from pprint import pprint
    from codetiming import Timer  # verl uses codetiming's Timer for the timing blocks below
    from verl import DataProto

    logger = Tracking(project_name=self.config.trainer.project_name,
                      experiment_name=self.config.trainer.experiment_name,
                      default_backend=self.config.trainer.logger,
                      config=OmegaConf.to_container(self.config, resolve=True))

    global_steps = 0

    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)
            # batch = batch.to('cuda')

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            # generate a batch
            with Timer(name='gen', logger=None) as timer:
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
            metrics['timing/gen'] = timer.last

            batch = batch.union(gen_batch_output)

            if self.use_reference_policy:
                # compute reference log_prob
                with Timer(name='ref', logger=None) as timer:
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)
                metrics['timing/ref'] = timer.last

            # compute values (only needed when a critic is used, e.g. with the GAE estimator)
            if self.use_critic:
                with Timer(name='values', logger=None) as timer:
                    values = self.critic_wg.compute_values(batch)
                    batch = batch.union(values)
                metrics['timing/values'] = timer.last

            with Timer(name='adv', logger=None) as timer:
                # compute scores. Support both model and function-based.
                # We first compute the scores using reward model. Then, we call reward_fn to combine
                # the results from reward model and rule-based results.
                if self.use_rm:
                    # we first compute reward model score
                    reward_tensor = self.rm_wg.compute_rm_score(batch)
                    batch = batch.union(reward_tensor)

                # we combine with rule-based rm
                reward_tensor = self.reward_fn(batch)
                batch.batch['token_level_scores'] = reward_tensor

                # compute rewards. apply_kl_penalty if available
                batch, kl_metrics = apply_kl_penalty(batch,
                                                     kl_ctrl=self.kl_ctrl_in_reward,
                                                     kl_penalty=self.config.algorithm.kl_penalty)
                metrics.update(kl_metrics)

                # compute advantages, executed on the driver process
                batch = compute_advantage(batch,
                                          self.config.algorithm.gamma,
                                          self.config.algorithm.lam,
                                          adv_estimator=self.config.algorithm.adv_estimator)
            metrics['timing/adv'] = timer.last

            # update critic
            if self.use_critic:
                with Timer(name='update_critic', logger=None) as timer:
                    critic_output = self.critic_wg.update_critic(batch)
                metrics['timing/update_critic'] = timer.last
                critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                metrics.update(critic_output_metrics)

            # implement critic warmup
            if self.config.trainer.critic_warmup <= global_steps:
                # update actor
                with Timer(name='update_actor', logger=None) as timer:
                    actor_output = self.actor_rollout_wg.update_actor(batch)
                metrics['timing/update_actor'] = timer.last
                actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                metrics.update(actor_output_metrics)

            # validate
            if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                    (global_steps + 1) % self.config.trainer.test_freq == 0:
                with Timer(name='testing', logger=None) as timer:
                    val_metrics: dict = self._validate()
                    val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
                metrics['timing/testing'] = timer.last
                metrics.update(val_metrics)

            # collect metrics
            data_metrics = compute_data_metrics(batch=batch)
            metrics.update(data_metrics)

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=global_steps)

            if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
                actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                                f'global_step_{global_steps}')
                actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
                self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

                if self.use_critic:
                    critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                                     f'global_step_{global_steps}')
                    critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                    self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

            global_steps += 1

    # perform validation after training
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Final validation metrics: {val_metrics}')
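Putting it together, a driver script constructs the trainer, initializes the WorkerGroups, and starts the loop. The sketch below assumes the surrounding setup (config, tokenizer, role_worker_mapping, resource_pool_manager and the reward functions) has already been prepared, as in verl's PPO entry point:

trainer = RayPPOTrainer(config=config,
                        tokenizer=tokenizer,
                        role_worker_mapping=role_worker_mapping,
                        resource_pool_manager=resource_pool_manager,
                        reward_fn=reward_fn,
                        val_reward_fn=val_reward_fn)
trainer.init_workers()  # build the WorkerGroups described above
trainer.fit()           # run the PPO training loop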