Extend to other RL(HF) algorithms

We have already implemented the complete training pipeline for the PPO algorithm. To extend verl to other algorithms, we analyze the high-level principles of using verl and provide a tutorial that implements the DPO algorithm. Users can follow a similar paradigm to extend to other RL algorithms.

Note

Key idea: A single process drives multi-process computation and data communication.

Overall Approach

Step 1: Consider what multi-machine multi-GPU computations are needed for each model, such as generate_sequences, compute_log_prob and update_policy in the actor_rollout model. Implement them as distributed single-program-multiple-data (SPMD) computations and encapsulate them into APIs.

Step 2: Based on different distributed scenarios, including FSDP and 3D parallelism in Megatron-LM, implement single-process control of data interaction among multi-process computations.

Step 3: Use the encapsulated APIs to implement the control flow.

Example: Online DPO

We use verl to implement a simple online DPO algorithm. The algorithm flow of Online DPO is as follows:

  1. There is a prompt (rollout) generator that has the same weights as the actor model. After a batch of prompts is fed into the generator, it generates N responses for each prompt.

  2. Send all the prompts + responses to a verifier for scoring, which can be a reward model or a rule-based function. Then use the scores to group the responses into pairs that form a training batch.

  3. Use this training batch to train the actor model using DPO. During the process, a reference policy is needed.
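
The scoring and pairing in step 2 can be sketched in plain Python. This is only an illustration under stated assumptions: build_preference_pairs and toy_rule_based_score are hypothetical names, not part of verl, and picking the best and worst response per prompt is one simple pairing strategy among several.

```python
from typing import Callable, Dict, List


def build_preference_pairs(
    prompts: List[str],
    responses: List[List[str]],
    score_fn: Callable[[str, str], float],
) -> List[Dict[str, str]]:
    """Turn N sampled responses per prompt into one (chosen, rejected) pair."""
    pairs = []
    for prompt, candidates in zip(prompts, responses):
        # Score every candidate once, then sort from best to worst.
        scored = [(score_fn(prompt, r), r) for r in candidates]
        ranked = sorted(scored, key=lambda item: item[0], reverse=True)
        (best_score, best), (worst_score, worst) = ranked[0], ranked[-1]
        # Skip prompts whose candidates all tie: they carry no preference signal.
        if best_score == worst_score:
            continue
        pairs.append({"prompt": prompt, "chosen": best, "rejected": worst})
    return pairs


def toy_rule_based_score(prompt: str, response: str) -> float:
    """Hypothetical rule-based verifier: reward 1.0 if the response contains '4'."""
    return 1.0 if "4" in response else 0.0


pairs = build_preference_pairs(
    prompts=["2 + 2 = ?", "Say hi"],
    responses=[["5", "4", "22"], ["hello", "hi"]],
    score_fn=toy_rule_based_score,
)
```

In a real setup the score function would call a reward model or a task-specific checker, and the resulting pairs would be tokenized before being passed to the actor update.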

Step 1: What are the multi-machine multi-GPU computations

Sample Generator

Implementation details:

from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray

@ray.remote
class SampleGenerator(Worker):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def generate_sequences(self, data):
        pass

Here, SampleGenerator can be viewed as a group of processes launched by torchrun, with each process running the same code (SPMD). SampleGenerator needs to implement a generate_sequences API for the control flow to call. The implementation inside can use any inference engine, including vLLM, SGLang and Hugging Face. Users can largely reuse the code in verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py, so we won’t go into details here.
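
To make the SPMD idea concrete, the sketch below shows how identical code can behave differently on each process purely through its rank. It assumes the RANK and WORLD_SIZE environment variables that torchrun exports to every process; shard_for_rank and spmd_entry are hypothetical helpers for illustration, not verl APIs.

```python
import os
from typing import List, Sequence


def shard_for_rank(items: Sequence[str], rank: int, world_size: int) -> List[str]:
    """Return the part of `items` that one SPMD process is responsible for."""
    # Strided sharding: rank r takes items r, r + world_size, r + 2 * world_size, ...
    return list(items[rank::world_size])


def spmd_entry(prompts: Sequence[str]) -> List[str]:
    """Same code on every process; only RANK and WORLD_SIZE differ."""
    # Defaults let the sketch run as a single standalone process.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return shard_for_rank(prompts, rank, world_size)


prompts = ["p0", "p1", "p2", "p3", "p4"]
# Simulate a 2-process launch by evaluating each rank's shard in one process.
shards = [shard_for_rank(prompts, rank, 2) for rank in range(2)]
```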

ReferencePolicy inference

API: compute reference log probability

from verl.single_controller.base import Worker
import ray

@ray.remote
class ReferencePolicy(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()  # placeholder: construct or load the reference model here

    def infer(self, data):
        return self.model(data)

Actor update

API: Update actor model parameters

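import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl.single_controller.base import Worker
import ray

@ray.remote
class DPOActor(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()  # placeholder: construct or load the actor model here
        self.model = FSDP(self.model)  # or other distributed strategy
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss_fn = xxx  # placeholder: the DPO loss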

    def update(self, data):
        self.optimizer.zero_grad()
        logits = self.model(data)
        loss = self.loss_fn(logits)
        loss.backward()
        self.optimizer.step()
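
The loss_fn placeholder above is left unspecified. As a reference point, here is a minimal pure-Python sketch of the standard DPO objective, -log sigmoid(beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))), computed on sequence-level log-probabilities. The function name, argument layout and default beta are assumptions for illustration; a real actor would compute this on tensors so that gradients flow to the policy.

```python
import math
from typing import List


def _softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(
    policy_chosen_logps: List[float],
    policy_rejected_logps: List[float],
    ref_chosen_logps: List[float],
    ref_rejected_logps: List[float],
    beta: float = 0.1,
) -> float:
    """Mean DPO loss over a batch of sequence-level log-probabilities."""
    losses = []
    for pc, pr, rc, rr in zip(
        policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps
    ):
        # Scaled difference between the policy's and the reference's preference
        # for the chosen response over the rejected one.
        margin = beta * ((pc - rc) - (pr - rr))
        # -log(sigmoid(margin)) is equal to softplus(-margin).
        losses.append(_softplus(-margin))
    return sum(losses) / len(losses)


# The policy prefers the chosen response more strongly than the reference does,
# so the loss falls below log(2), its value at zero margin.
loss = dpo_loss([-1.0], [-5.0], [-2.0], [-2.0])
```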

Notes: How to distinguish between control processes and distributed computation processes

  • Control processes are generally functions directly decorated with @ray.remote.

  • Computation processes are all wrapped into a RayWorkerGroup.

Users can reuse most of the distributed computation logic implemented for the PPO algorithm, including the FSDP and Megatron-LM backends in verl/verl/trainer/ppo.

Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction

The core problem to solve here is how a single process sends data to multiple processes, drives the multi-process computation, and obtains the results of that computation. First, we initialize the multi-process WorkerGroup in the control process.

@ray.remote(num_cpus=1)
def main_task(config):
    # construct SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # put SampleGenerator onto resource pool
    worker_group = RayWorkerGroup(resource_pool, ray_cls)

    # construct reference policy

As we can see, in the control process, multiple processes are wrapped into a RayWorkerGroup. Inside this WorkerGroup, there is a self._workers member, where each worker is a Ray actor (https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator. ray_trainer.md also provides an implementation of MegatronRayWorkerGroup.

Assume the model is distributed using FSDP and there is a batch of data on the control process. For data parallelism, the underlying calling process is:

data = xxx
data_list = data.chunk(dp_size)

output = []
for i, d in enumerate(data_list):
    # worker_group._workers[i] is a SampleGenerator
    output.append(worker_group._workers[i].generate_sequences.remote(d))

output = ray.get(output)
output = torch.cat(output)

Single process calling multiple processes involves the following 3 steps:

  1. Split the data into DP parts on the control process.

  2. Send the data to remote, call the remote computation through RPC, and utilize multi-process computation.

  3. Obtain the computation results of each worker on the control process and merge them.
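
The three steps above can be tried out without a Ray cluster. The sketch below uses a standard-library thread pool as a stand-in for remote workers, so submit plays the role of the .remote(...) call and result() the role of ray.get. The helper names chunk, worker_compute and run_on_workers are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def chunk(data: List[int], num_chunks: int) -> List[List[int]]:
    """Split `data` into `num_chunks` contiguous, near-equal parts."""
    size, extra = divmod(len(data), num_chunks)
    chunks, start = [], 0
    for i in range(num_chunks):
        end = start + size + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks


def worker_compute(shard: List[int]) -> List[int]:
    """Stand-in for the per-worker computation (e.g. generate_sequences)."""
    return [x * x for x in shard]


def run_on_workers(data: List[int], dp_size: int) -> List[int]:
    shards = chunk(data, dp_size)  # step 1: split on the control process
    with ThreadPoolExecutor(max_workers=dp_size) as pool:
        # step 2: send each shard to a worker and start the computation
        futures = [pool.submit(worker_compute, s) for s in shards]
        # step 3: gather results in submission order ...
        results = [f.result() for f in futures]
    # ... and merge them back into one batch
    return [y for part in results for y in part]


output = run_on_workers(list(range(6)), dp_size=3)
```

Collecting results in submission order rather than completion order is what keeps the merged output aligned with the input batch.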

Repeating these 3 steps on the controller process for every call greatly hurts code readability. In verl, we have abstracted and encapsulated them, so that a worker’s method can be registered into the worker_group together with its dispatch and collect functions.

import torch
from verl.single_controller.base.decorator import register

def dispatch_data(worker_group, data):
    return data.chunk(worker_group.world_size)

def collect_data(worker_group, data):
    return torch.cat(data)

dispatch_mode = {
    'dispatch_fn': dispatch_data,
    'collect_fn': collect_data
}

@register(dispatch_mode=dispatch_mode)
def generate_sequences(self, data):
    pass

In this way, we can directly call the method inside the worker through the worker_group on the control (driver) process (which is a single process):

output = worker_group.generate_sequences(data)

This single line includes data splitting, data distribution and computation, and data collection.
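
To build intuition for how one call can hide all three steps, here is a toy re-implementation of the pattern in plain Python. It is not verl's actual implementation: toy_register and ToyWorkerGroup are hypothetical names, workers are ordinary in-process objects rather than Ray actors, and the batch size is assumed to be divisible by the world size.

```python
from typing import Callable, List


def toy_register(dispatch_fn: Callable, collect_fn: Callable) -> Callable:
    """Attach dispatch/collect functions to a worker method as attributes."""
    def decorator(method: Callable) -> Callable:
        method._dispatch_fn = dispatch_fn
        method._collect_fn = collect_fn
        return method
    return decorator


def split_evenly(group: "ToyWorkerGroup", data: List[str]) -> List[List[str]]:
    # Assumes len(data) is divisible by the world size.
    step = len(data) // group.world_size
    return [data[i * step:(i + 1) * step] for i in range(group.world_size)]


def concat(group: "ToyWorkerGroup", outputs: List[List[str]]) -> List[str]:
    return [item for part in outputs for item in part]


class ToyWorker:
    @toy_register(dispatch_fn=split_evenly, collect_fn=concat)
    def generate_sequences(self, data: List[str]) -> List[str]:
        return [prompt + " -> response" for prompt in data]


class ToyWorkerGroup:
    def __init__(self, worker_cls: type, world_size: int):
        self.world_size = world_size
        self._workers = [worker_cls() for _ in range(world_size)]
        # Expose every registered worker method as a group-level method.
        for name in dir(worker_cls):
            method = getattr(worker_cls, name)
            if callable(method) and hasattr(method, "_dispatch_fn"):
                setattr(self, name, self._bind(name, method))

    def _bind(self, name: str, method: Callable) -> Callable:
        def call(data):
            shards = method._dispatch_fn(self, data)  # split
            outputs = [getattr(w, name)(s) for w, s in zip(self._workers, shards)]  # compute
            return method._collect_fn(self, outputs)  # merge
        return call


group = ToyWorkerGroup(ToyWorker, world_size=2)
output = group.generate_sequences(["a", "b", "c", "d"])
```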

Furthermore, the parallelism sizes of each model (dp, tp, pp) are usually fixed. For these common distributed scenarios, we have pre-implemented specific dispatch and collect methods in decorator.py, which can be used directly to wrap the computations.

from verl.single_controller.base.decorator import register, Dispatch

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto:
    pass

Here, the data interface is required to be DataProto. The definition of DataProto is in protocol.py.
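
A predefined dispatch mode needs a data container that it knows how to split and merge. The sketch below is a deliberately simplified stand-in that shows only that contract: ToyProto is a hypothetical class holding plain Python lists, and it is not the real DataProto, whose actual fields and methods should be looked up in protocol.py.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyProto:
    """Hypothetical batch container: a dict of equally long columns."""
    batch: Dict[str, List[float]] = field(default_factory=dict)

    def __len__(self) -> int:
        return len(next(iter(self.batch.values()))) if self.batch else 0

    def chunk(self, chunks: int) -> List["ToyProto"]:
        """Split every column into `chunks` contiguous, equal parts."""
        if len(self) % chunks != 0:
            raise ValueError("batch size must be divisible by the number of chunks")
        step = len(self) // chunks
        return [
            ToyProto({k: v[i * step:(i + 1) * step] for k, v in self.batch.items()})
            for i in range(chunks)
        ]

    @staticmethod
    def concat(parts: List["ToyProto"]) -> "ToyProto":
        """Merge per-worker results back into one batch, preserving order."""
        merged: Dict[str, List[float]] = {}
        for part in parts:
            for key, column in part.batch.items():
                merged.setdefault(key, []).extend(column)
        return ToyProto(merged)


proto = ToyProto({"score": [1.0, 2.0, 3.0, 4.0]})
parts = proto.chunk(2)
merged = ToyProto.concat(parts)
```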

Step 3: Main training loop

With the above building blocks, we can implement the algorithm’s control flow. It is recommended that main_task also be a Ray remote process.

@ray.remote(num_cpus=1)
def main_task(config):
    # construct SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # put SampleGenerator onto resource pool
    sample_gen = RayWorkerGroup(resource_pool, ray_cls)

    # construct reference policy
    ray_cls = RayClassWithInitArgs(ReferencePolicy)
    ref_policy = RayWorkerGroup(resource_pool, ray_cls)

    # construct actor
    ray_cls = RayClassWithInitArgs(DPOActor)
    dpo_policy = RayWorkerGroup(resource_pool, ray_cls)

    dataloader = DataLoader()

    for data in dataloader:
        # generate data
        data = sample_gen.generate_sequences(data)
        # generate scores for each data
        data = generate_scores(data)
        # generate pairwise data using scores
        data = generate_pairwise_data(data)
        # generate ref_log_prob
        data.batch['ref_log_prob'] = ref_policy.infer(data)
        # update using dpo
        dpo_policy.update(data)
        # logging

Here, different WorkerGroups can be placed in the same resource pool or in different resource pools using create_colocated_worker_cls, similar to what is done in ray_trainer.py.