Extend to other RL(HF) algorithms
We have already implemented the complete training pipeline of the PPO algorithm. To extend verl to other algorithms, we analyze the high-level principles behind using verl and provide a tutorial for implementing the DPO algorithm. Users can follow a similar paradigm to extend to other RL algorithms.
Note
Key idea: a single process drives multi-process computation and data communication.
Overall Approach
Step 1: Consider what multi-machine multi-GPU computations are needed for each model, such as generate_sequences, compute_log_prob and update_policy in the actor_rollout model. Implement distributed single-program-multiple-data (SPMD) computations and encapsulate them into APIs.
Step 2: Based on different distributed scenarios, including FSDP and 3D parallelism in Megatron-LM, implement single-process control of data interaction among multi-process computations.
Step 3: Utilize the encapsulated APIs to implement the control flow.
Example: Online DPO
We use verl to implement a simple online DPO algorithm. The algorithm flow of Online DPO is as follows:
There is a prompt (rollout) generator which has the same weights as the actor model. After a batch of prompts is fed into the generator, it generates N responses for each prompt.
Send all the prompts + responses to a verifier for scoring, which can be a reward model or a rule-based function. Then sort them into pairs to form a training batch (a sketch of this pairing step is given right after this list).
Use this training batch to train the actor model using DPO. During the process, a reference policy is needed.
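The pairing step is the only DPO-specific data processing in this flow. Below is a minimal sketch of one possible pairing rule, keeping the highest- and lowest-scored response for each prompt; the function name and data layout are illustrative and not part of verl:

def make_preference_pairs(prompts, responses, scores, n):
    # prompts: list of B prompts; responses/scores: flat lists of B * n entries,
    # where responses[i * n:(i + 1) * n] belong to prompts[i]
    pairs = []
    for i, prompt in enumerate(prompts):
        group = list(zip(responses[i * n:(i + 1) * n], scores[i * n:(i + 1) * n]))
        group.sort(key=lambda x: x[1])  # ascending by verifier score
        chosen, rejected = group[-1][0], group[0][0]
        pairs.append({'prompt': prompt, 'chosen': chosen, 'rejected': rejected})
    return pairs

A helper like this plays the role of generate_pairwise_data in the main training loop shown in Step 3.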
Step 1: What are the multi-machine multi-GPU computations
Sample Generator
Implementation details:
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray

@ray.remote
class SampleGenerator(Worker):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def generate_sequences(self, data):
        pass
Here, SampleGenerator can be viewed as a group of processes pulled up by torchrun, with each process running the same code (SPMD). SampleGenerator needs to implement a generate_sequences API for the control flow to call. The implementation details inside can use any inference engine, including vLLM, SGLang and HuggingFace. Users can largely reuse the code in verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py and we won't go into details here.
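For illustration only, a naive generate_sequences could wrap a HuggingFace model directly. The worker class, the config key model_path and the generation settings below are assumptions for this sketch; verl's actual rollout lives in the vllm_rollout file mentioned above:

import torch
import ray
from transformers import AutoModelForCausalLM, AutoTokenizer
from verl.single_controller.base import Worker

@ray.remote
class NaiveSampleGenerator(Worker):  # illustrative only, not verl's rollout worker
    def __init__(self, config):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(config['model_path'])
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(config['model_path']).cuda()

    @torch.no_grad()
    def generate_sequences(self, prompts):
        # prompts: the list of strings assigned to this worker (one DP shard)
        inputs = self.tokenizer(prompts, return_tensors='pt', padding=True).to('cuda')
        outputs = self.model.generate(**inputs, max_new_tokens=128, do_sample=True)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)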
ReferencePolicy inference
API: compute reference log probability
from verl.single_controller.base import Worker
import ray

@ray.remote
class ReferencePolicy(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()

    def infer(self, data):
        return self.model(data)
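In practice, infer should return the per-token log probability of each sampled response under the reference model, rather than raw model outputs. A minimal sketch of that computation, assuming input_ids and attention_mask hold the concatenated prompt + response tokens:

import torch
import torch.nn.functional as F

@torch.no_grad()
def compute_log_prob(model, input_ids, attention_mask):
    # input_ids: (batch, seq_len) prompt + response tokens
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # the log-prob of token t is read from the distribution predicted at position t - 1
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    token_log_prob = torch.gather(
        log_probs, dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return token_log_prob  # (batch, seq_len - 1)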
Actor update
API: Update actor model parameters
from verl.single_controller.base import Worker
import ray

@ray.remote
class DPOActor(Worker):
    def __init__(self):
        super().__init__()
        self.model = Model()
        self.model = FSDP(self.model)  # or other distributed strategy
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss_fn = xxx

    def update(self, data):
        self.optimizer.zero_grad()
        logits = self.model(data)
        loss = self.loss_fn(logits)
        loss.backward()
        self.optimizer.step()
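The loss_fn placeholder above is where the DPO objective goes. A minimal sketch of the standard DPO loss, assuming the per-sequence log probabilities (summed over response tokens) of the chosen and rejected responses are already available for both the policy and the reference model:

import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # log-ratios between policy and reference for each preference pair
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    # maximize the margin between chosen and rejected responses
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()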
Notes: How to distinguish between control processes and distributed computation processes
Control processes are generally functions directly decorated with @ray.remote.
Computation processes are all wrapped into a RayWorkerGroup.
Users can reuse most of the distributed computation logic implemented in the PPO algorithm, including the FSDP and Megatron-LM backends in verl/verl/trainer/ppo.
Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction
The core problem to solve here is how a single process sends data to
multiple processes, drives multi-process computation, and how the
control process obtains the results of multi-process computation.
First, we initialize the multi-process WorkerGroup in the control process.
@ray.remote(num_cpus=1)
def main_task(config):
    # construct SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # put SampleGenerator onto resource pool
    worker_group = RayWorkerGroup(resource_pool, ray_cls)
    # construct reference policy
As we can see, in the control process, multiple processes are wrapped into a RayWorkerGroup. Inside this WorkerGroup, there is a self._workers member, where each worker is a Ray actor (https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator. ray_trainer.md also provides an implementation of MegatronRayWorkerGroup.
Assuming the model is distributed using FSDP, and there is a batch of data on the control process, for data parallelism, the underlying calling process is:
data = xxx
data_list = data.chunk(dp_size)

output = []
for i, d in enumerate(data_list):
    # worker_group._workers[i] is a SampleGenerator
    output.append(worker_group._workers[i].generate_sequences.remote(d))

output = ray.get(output)
output = torch.cat(output)
Single process calling multiple processes involves the following 3 steps:
Split the data into DP parts on the control process.
Send the data to remote, call the remote computation through RPC, and utilize multi-process computation.
Obtain the computation results of each worker on the control process and merge them.
Frequently calling these 3 steps on the controller process greatly hurts code readability. In verl, we have abstracted and encapsulated these 3 steps, so that the worker's method + dispatch + collect can be registered into the worker_group.
from verl.single_controller.base.decorator import register

def dispatch_data(worker_group, data):
    return data.chunk(worker_group.world_size)

def collect_data(worker_group, data):
    return torch.cat(data)

dispatch_mode = {
    'dispatch_fn': dispatch_data,
    'collect_fn': collect_data
}

@register(dispatch_mode=dispatch_mode)
def generate_sequences(self, data):
    pass
In this way, we can directly call the method inside the worker through the worker_group on the control (driver) process (which is a single process):
output = worker_group.generate_sequences(data)
This single line includes data splitting, data distribution and computation, and data collection.
Furthermore, the model parallelism size of each model is usually fixed, including dp, tp and pp. So for these common distributed scenarios, we have pre-implemented specific dispatch and collect methods in decorator.py, which can be directly used to wrap the computations.
from verl.single_controller.base.decorator import register, Dispatch

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, data: DataProto) -> DataProto:
    pass
Here it requires the data interface to be DataProto. The definition of DataProto is in protocol.py.
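For reference, a DataProto wraps a dictionary of equally-sized tensors and supports the chunk/concat operations used by the dispatch and collect functions above. The snippet below is a sketch based on the DataProto interface in protocol.py at the time of writing (from_dict, batch, chunk, concat); the keys are illustrative:

import torch
from verl.protocol import DataProto

# build a batch from plain tensors
data = DataProto.from_dict(tensors={
    'input_ids': torch.randint(0, 1000, (16, 128)),
    'attention_mask': torch.ones(16, 128, dtype=torch.long),
})

chunks = data.chunk(4)             # split along the batch dim, e.g. for 4 DP ranks
merged = DataProto.concat(chunks)  # inverse operation used on the driver
print(merged.batch['input_ids'].shape)  # torch.Size([16, 128])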
Step 3: Main training loop
With the building blocks above, we can implement the algorithm's control flow. It is recommended that main_task also runs as a ray remote process.
@ray.remote(num_cpus=1)
def main_task(config):
    # construct SampleGenerator
    resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
    ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
    # put SampleGenerator onto resource pool
    sample_gen = RayWorkerGroup(resource_pool, ray_cls)

    # construct reference policy
    ray_cls = RayClassWithInitArgs(ReferencePolicy)
    ref_policy = RayWorkerGroup(resource_pool, ray_cls)

    # construct actor
    ray_cls = RayClassWithInitArgs(DPOActor)
    dpo_policy = RayWorkerGroup(resource_pool, ray_cls)

    dataloader = DataLoader()

    for data in dataloader:
        # generate data
        data = sample_gen.generate_sequences(data)
        # generate scores for each data
        data = generate_scores(data)
        # generate pairwise data using scores
        data = generate_pairwise_data(data)
        # generate ref_log_prob
        data.batch['ref_log_prob'] = ref_policy.infer(data)
        # update using dpo
        dpo_policy.update(data)
        # logging
Here, different WorkerGroups can be placed in the same resource pool or in different resource pools using create_colocated_worker_cls, similar to ray_trainer.py.
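A rough sketch of the colocation pattern follows. The exact signature of create_colocated_worker_cls and of the spawn call may differ across verl versions, so treat the names below as assumptions rather than a definitive API:

from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.base import create_colocated_worker_cls

resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs shared by all roles

# one RayClassWithInitArgs per role, all placed on the same resource pool
class_dict = {
    'ref': RayClassWithInitArgs(ReferencePolicy),
    'actor': RayClassWithInitArgs(DPOActor),
}
colocated_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = RayWorkerGroup(resource_pool, colocated_cls)

# split the fused worker group back into per-role worker groups
spawned = wg_dict.spawn(prefix_set=class_dict.keys())
ref_policy, dpo_policy = spawned['ref'], spawned['actor']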