Data interface

DataProto is the interface for data exchange.

The verl.DataProto class contains the following key members:

  • batch: a tensordict.TensorDict object for the actual tensor data

  • non_tensor_batch: a Dict for data that is not stored as tensors

  • meta_info: a Dict with additional meta information

TensorDict

DataProto.batch is built on top of tensordict, a project in the PyTorch ecosystem. A TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size.

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 5)}, batch_size=[2,])
>>> tensordict["twos"] = 2 * torch.ones(2, 5, 6)
>>> zeros = tensordict["zeros"]
>>> tensordict
TensorDict(
fields={
    ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
    zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)

One can also index a tensordict along its batch dimensions. The contents of a TensorDict can also be manipulated collectively, for example moved to another device or reshaped.

>>> tensordict[..., :1]
TensorDict(
fields={
    ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
    zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
>>> tensordict = tensordict.to("cuda:0")  # requires a CUDA device
>>> tensordict = tensordict.reshape(1, 2)  # batch_size becomes torch.Size([1, 2])

For more about tensordict.TensorDict usage, see the official tensordict documentation.

Core APIs

class verl.DataProto(batch: TensorDict = None, non_tensor_batch: Dict = <factory>, meta_info: Dict = <factory>)

A DataProto is a data structure that provides a standard protocol for data exchange between functions. It contains a batch (TensorDict), a non_tensor_batch (Dict) and a meta_info (Dict). The batch is a TensorDict (https://pytorch.org/tensordict/), which lets you manipulate a dictionary of tensors as if it were a single tensor. Ideally, tensors that share the same batch size should be put inside batch.
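To make the layout concrete, here is a toy, pure-Python model of the three fields that uses plain lists in place of tensors, so it runs without torch or verl. The class `ToyDataProto` and its `check_batch_size` helper are hypothetical and only mirror the description above; this is not verl's implementation. It also shows where `<factory>` in the signature comes from: that is how Python renders a dataclass field whose default is produced by `default_factory`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ToyDataProto:
    # Stand-in for the TensorDict: each value is a list whose length is the batch size.
    batch: Optional[Dict[str, List[Any]]] = None
    # Assumption: per-sample values that are not tensors (e.g. strings).
    non_tensor_batch: Dict[str, List[Any]] = field(default_factory=dict)
    # Batch-level information shared by all samples.
    meta_info: Dict[str, Any] = field(default_factory=dict)

    def check_batch_size(self) -> int:
        """Return the common batch size, raising if entries disagree."""
        sizes = {len(v) for v in self.batch.values()}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
        return sizes.pop()


data = ToyDataProto(
    batch={"input_ids": [[1, 2], [3, 4]], "rewards": [0.5, 1.0]},
    meta_info={"temperature": 1.0},
)
print(data.check_batch_size())  # 2
print(data.non_tensor_batch)    # {} (created by default_factory)
```

Using `default_factory` rather than a literal `{}` default gives every instance its own fresh dict, which is why the signature cannot show a plain value there.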

static concat(data: List[DataProto]) → DataProto

Concatenate a list of DataProto. The batch is concatenated along dim=0. The meta_info is assumed to be identical across the inputs, so the first one is used.

Args:

data (List[DataProto]): list of DataProto

Returns:

DataProto: concatenated DataProto
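A toy model of these concat semantics, again with lists standing in for tensors: every batch key is concatenated along the first (sample) dimension, and meta_info is taken from the first element without being compared. `Toy` and `toy_concat` are hypothetical stand-ins written for illustration, not verl code.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Toy:
    # Lists stand in for tensors; len(list) is the size of dim 0.
    batch: Dict[str, List[Any]]
    meta_info: Dict[str, Any] = field(default_factory=dict)


def toy_concat(data: List[Toy]) -> Toy:
    """Concatenate batches along dim 0 and keep the first element's meta_info."""
    keys = data[0].batch.keys()
    batch = {k: [row for item in data for row in item.batch[k]] for k in keys}
    # meta_info is assumed identical across inputs, so only the first is kept.
    return Toy(batch=batch, meta_info=dict(data[0].meta_info))


a = Toy({"rewards": [0.1, 0.2]}, {"step": 3})
b = Toy({"rewards": [0.3]}, {"step": 3})
merged = toy_concat([a, b])
print(merged.batch["rewards"])  # [0.1, 0.2, 0.3]
print(merged.meta_info)         # {'step': 3}
```

With the real class, the equivalent call is `DataProto.concat([a, b])`, since concat is a static method.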

make_iterator(mini_batch_size, epochs, seed=None, dataloader_kwargs=None)

Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a regular PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

Args:

mini_batch_size (int): mini-batch size when iterating the dataset. We require that batch.batch_size[0] % mini_batch_size == 0.

epochs (int): number of epochs when iterating the dataset.

dataloader_kwargs (Any): internally, the method builds a DataLoader over the batch, and dataloader_kwargs are passed to that DataLoader.

Returns:

Iterator: an iterator that yields one mini-batch at a time. The total number of iteration steps is self.batch.batch_size[0] * epochs // mini_batch_size.
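The step count above is plain arithmetic. The hypothetical `toy_make_iterator` below walks a list of samples in mini-batches for the requested number of epochs, optionally shuffling each epoch with a seeded `random.Random`. It only illustrates the bookkeeping and is not verl's DataLoader-based implementation; the `shuffle` flag is an assumption of the toy (for the real method, shuffling would presumably be requested through dataloader_kwargs). With 8 samples, mini_batch_size=4 and epochs=3 it yields 8 * 3 // 4 = 6 mini-batches.

```python
import random
from typing import Iterator, List, Optional, Sequence


def toy_make_iterator(
    samples: Sequence[int],
    mini_batch_size: int,
    epochs: int,
    seed: Optional[int] = None,
    shuffle: bool = False,
) -> Iterator[List[int]]:
    """Yield mini-batches of `samples` for `epochs` passes (toy model)."""
    # Same precondition as the documented method.
    if len(samples) % mini_batch_size != 0:
        raise ValueError("batch size must be divisible by mini_batch_size")
    rng = random.Random(seed)  # seeded so shuffled runs are reproducible
    for _ in range(epochs):
        order = list(samples)
        if shuffle:
            rng.shuffle(order)
        for start in range(0, len(order), mini_batch_size):
            yield order[start:start + mini_batch_size]


steps = list(toy_make_iterator(range(8), mini_batch_size=4, epochs=3))
print(len(steps))  # 6 == 8 * 3 // 4
print(steps[0])    # [0, 1, 2, 3]
```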

select(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) → DataProto

Select a subset of the DataProto via batch_keys, non_tensor_batch_keys and meta_info_keys.

Args:

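batch_keys (list, optional): a list of strings indicating the keys in batch to select

non_tensor_batch_keys (list, optional): a list of strings indicating the keys in non_tensor_batch to select

meta_info_keys (list, optional): a list of keys indicating the meta info to select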

Returns:

DataProto: the DataProto with the selected batch_keys and meta_info_keys
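A toy model of select over a dict-based stand-in for DataProto. The names `toy_select` and `_pick` are hypothetical, and the rule that a key list of None keeps the whole field is an assumption of this sketch rather than documented behavior; the deepcopy option is not modeled.

```python
from typing import Any, Dict, Iterable, Optional


def _pick(d: Dict[str, Any], keys: Optional[Iterable[str]]) -> Dict[str, Any]:
    # Assumption: keys=None keeps the whole field unchanged.
    if keys is None:
        return dict(d)
    return {k: d[k] for k in keys}  # a missing key raises KeyError


def toy_select(
    data: Dict[str, Dict[str, Any]],
    batch_keys: Optional[Iterable[str]] = None,
    non_tensor_batch_keys: Optional[Iterable[str]] = None,
    meta_info_keys: Optional[Iterable[str]] = None,
) -> Dict[str, Dict[str, Any]]:
    """Return a new toy DataProto holding only the requested keys of each field."""
    return {
        "batch": _pick(data["batch"], batch_keys),
        "non_tensor_batch": _pick(data["non_tensor_batch"], non_tensor_batch_keys),
        "meta_info": _pick(data["meta_info"], meta_info_keys),
    }


data = {
    "batch": {"input_ids": [[1, 2], [3, 4]], "rewards": [0.5, 1.0]},
    "non_tensor_batch": {"uid": ["a", "b"]},
    "meta_info": {"temperature": 1.0, "step": 7},
}
sub = toy_select(data, batch_keys=["rewards"], meta_info_keys=["step"])
print(sub["batch"])             # {'rewards': [0.5, 1.0]}
print(sub["non_tensor_batch"])  # {'uid': ['a', 'b']}
print(sub["meta_info"])         # {'step': 7}
```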

to(device) → DataProto

Move the batch to device.

Args:

device (torch.device, str): torch device

Returns:

DataProto: the current DataProto

union(other: DataProto) → DataProto

Union with another DataProto. The batch and meta_info are unioned separately. Throws an error if:

  • there are conflicting keys in batch whose values are not equal

  • the batch sizes of the two batches are not the same

  • there are conflicting keys in meta_info whose values are not the same.

Args:

other (DataProto): another DataProto to union

Returns:

DataProto: the DataProto after union
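The three error conditions can be modeled with a small pure-Python toy in which lists stand in for tensors and plain `!=` stands in for tensor equality. `toy_union`, `_merge` and `_batch_size` are hypothetical helpers written for illustration; they are not verl's implementation.

```python
from typing import Any, Dict, List


def _batch_size(batch: Dict[str, List[Any]]) -> int:
    # Lists stand in for tensors; dim 0 is the sample dimension (0 if empty).
    return len(next(iter(batch.values()), []))


def _merge(mine: Dict[str, Any], other: Dict[str, Any], what: str) -> Dict[str, Any]:
    merged = dict(mine)
    for key, value in other.items():
        # A shared key is allowed only if both sides hold the same value.
        if key in merged and merged[key] != value:
            raise ValueError(f"conflicting {what} key {key!r}")
        merged[key] = value
    return merged


def toy_union(a: Dict[str, Dict[str, Any]], b: Dict[str, Dict[str, Any]]):
    """Union batch and meta_info separately, mirroring the three documented errors."""
    if _batch_size(a["batch"]) != _batch_size(b["batch"]):
        raise ValueError("batch sizes differ")
    return {
        "batch": _merge(a["batch"], b["batch"], "batch"),
        "meta_info": _merge(a["meta_info"], b["meta_info"], "meta_info"),
    }


old = {"batch": {"input_ids": [[1], [2]]}, "meta_info": {"step": 1}}
new = {"batch": {"input_ids": [[1], [2]], "values": [0.3, 0.4]}, "meta_info": {"lr": 1e-6}}
out = toy_union(old, new)
print(sorted(out["batch"]))  # ['input_ids', 'values']
print(out["meta_info"])      # {'step': 1, 'lr': 1e-06}
```

A typical use of this pattern is attaching newly computed per-sample outputs (here `values`) to an existing batch that shares the same samples.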