Data interface
DataProto is the interface for data exchange.
The verl.DataProto class contains two key members:
- batch: a tensordict.TensorDict object for the actual data
- meta_info: a Dict with additional meta information
(The constructor also accepts a third member, non_tensor_batch; see Core APIs.)
TensorDict
DataProto.batch is built on top of tensordict, a project in the PyTorch ecosystem.
A TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size.
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 5)}, batch_size=[2,])
>>> tensordict["twos"] = 2 * torch.ones(2, 5, 6)
>>> zeros = tensordict["zeros"]
>>> tensordict
TensorDict(
fields={
ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
A TensorDict can also be indexed along its batch dimensions, and its contents can be manipulated collectively as well.
>>> tensordict[..., :1]
TensorDict(
fields={
ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
>>> tensordict = tensordict.to("cuda:0")  # requires a CUDA device
>>> tensordict = tensordict.reshape(1, 2)  # reshapes the batch dims only: [2] -> [1, 2]
For more about tensordict.TensorDict usage, see the official tensordict documentation.
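The shape rules behind these examples can be modeled without tensordict: every field's leading dimensions must equal batch_size, and batch-level operations such as reshape act only on those leading dimensions, so the total number of batch elements must stay the same. The sketch below is a plain-Python model on shape tuples; `check_batch_size` and `reshape_batch` are hypothetical helpers, not tensordict APIs.

```python
from math import prod
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]


def check_batch_size(fields: Dict[str, Shape], batch_size: Sequence[int]) -> None:
    """Raise if a field's leading dims do not equal batch_size."""
    b = tuple(batch_size)
    for key, shape in fields.items():
        if tuple(shape[: len(b)]) != b:
            raise ValueError(f"{key}: shape {shape} does not start with batch_size {b}")


def reshape_batch(fields: Dict[str, Shape], batch_size: Sequence[int],
                  new_batch_size: Sequence[int]) -> Dict[str, Shape]:
    """Field shapes after reshaping only the batch dims; trailing dims are untouched."""
    b, nb = tuple(batch_size), tuple(new_batch_size)
    if prod(b) != prod(nb):
        raise ValueError(f"cannot reshape batch_size {b} into {nb}")
    return {key: nb + tuple(shape[len(b):]) for key, shape in fields.items()}


# Shapes taken from the TensorDict example above.
fields = {"zeros": (2, 3, 4), "ones": (2, 3, 5), "twos": (2, 5, 6)}
check_batch_size(fields, [2])              # OK: every shape starts with (2,)
print(reshape_batch(fields, [2], [1, 2]))
# {'zeros': (1, 2, 3, 4), 'ones': (1, 2, 3, 5), 'twos': (1, 2, 5, 6)}
```

Note that a batch_size of [2, 3] would be rejected for these fields, because "twos" has shape (2, 5, 6).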
Core APIs
- class verl.DataProto(batch: TensorDict = None, non_tensor_batch: Dict = <factory>, meta_info: Dict = <factory>)
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict), a non_tensor_batch (Dict) for data that are not tensors, and a meta_info (Dict). The batch is a TensorDict (https://pytorch.org/tensordict/); TensorDict lets you manipulate a dictionary of tensors as if it were a single tensor. Ideally, tensors that share the same batch size should be put inside batch.
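As a rough mental model, DataProto can be pictured as a dataclass with three dict fields; the `<factory>` defaults in the signature are how Python renders a dataclass `default_factory`. The sketch below is a hypothetical stand-in (`ToyDataProto` is not a verl class) that uses lists instead of tensors and assumes that every per-sample entry must share one batch size. The key names in the usage example are illustrative only.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyDataProto:
    """Hypothetical stand-in for verl.DataProto that uses lists instead of tensors."""
    batch: Dict[str, List[Any]] = field(default_factory=dict)
    non_tensor_batch: Dict[str, List[Any]] = field(default_factory=dict)
    meta_info: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Assumed invariant: every per-sample entry shares one batch size.
        lengths = {len(v) for v in self.batch.values()}
        lengths |= {len(v) for v in self.non_tensor_batch.values()}
        if len(lengths) > 1:
            raise ValueError(f"inconsistent batch sizes: {sorted(lengths)}")

    def __len__(self) -> int:
        # Batch size = length of any per-sample entry (0 if there are none).
        for values in (*self.batch.values(), *self.non_tensor_batch.values()):
            return len(values)
        return 0


data = ToyDataProto(
    batch={"input_ids": [[1, 2], [3, 4]], "attention_mask": [[1, 1], [1, 0]]},
    non_tensor_batch={"raw_prompt": ["hi", "hello"]},
    meta_info={"temperature": 1.0},
)
print(len(data))  # 2
```

Using `default_factory` gives each instance its own fresh dict, which avoids the shared-mutable-default pitfall.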
- static concat(data: List[DataProto]) → DataProto
Concatenate a list of DataProto. The batch is concatenated along dim=0. The meta_info is assumed to be identical across inputs, and the first one is used.
- Args:
data (List[DataProto]): list of DataProto
- Returns:
DataProto: concatenated DataProto
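The concat semantics can be sketched with plain Python lists standing in for tensors, where "concatenate along dim=0" means appending the samples of each input in order. `Proto` and `concat` here are hypothetical illustrations, not verl's implementation.

```python
from typing import Any, Dict, List, NamedTuple


class Proto(NamedTuple):
    """Hypothetical stand-in: batch maps keys to per-sample lists."""
    batch: Dict[str, List[Any]]
    meta_info: Dict[str, Any]


def concat(data: List[Proto]) -> Proto:
    if not data:
        raise ValueError("concat needs at least one Proto")
    # Concatenate each batch entry along its first (sample) dimension.
    batch = {key: [row for d in data for row in d.batch[key]] for key in data[0].batch}
    # meta_info is assumed identical across inputs, so the first one is used.
    return Proto(batch=batch, meta_info=data[0].meta_info)


a = Proto(batch={"x": [1, 2]}, meta_info={"step": 0})
b = Proto(batch={"x": [3]}, meta_info={"step": 0})
print(concat([a, b]).batch)  # {'x': [1, 2, 3]}
```

Because meta_info is taken from the first input without checking the others, differing meta_info values are silently dropped.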
- make_iterator(mini_batch_size, epochs, seed=None, dataloader_kwargs=None)
Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a regular PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
- Args:
mini_batch_size (int): mini-batch size when iterating the dataset. We require that batch.batch_size[0] % mini_batch_size == 0.
epochs (int): number of epochs when iterating the dataset.
dataloader_kwargs (Any): internally, a DataLoader is built over the batch; dataloader_kwargs are the kwargs passed to that DataLoader.
- Returns:
Iterator: an iterator that yields one mini-batch of data at a time. The total number of iteration steps is self.batch.batch_size[0] * epochs // mini_batch_size.
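The step-count formula follows from splitting the batch into equal-sized mini-batches once per epoch. The sketch below models this with sample indices only; `make_index_iterator` is a hypothetical helper, its `shuffle` flag stands in for passing `shuffle=True` through dataloader_kwargs to a PyTorch DataLoader (an assumption about typical usage), and its `seed` seeds Python's `random.Random`, which is not necessarily how verl uses its own seed argument.

```python
import random
from typing import Iterator, List, Optional


def make_index_iterator(batch_size: int, mini_batch_size: int, epochs: int,
                        seed: Optional[int] = None,
                        shuffle: bool = False) -> Iterator[List[int]]:
    """Yield lists of sample indices, one mini-batch at a time (hypothetical model)."""
    if batch_size % mini_batch_size != 0:
        raise ValueError("batch_size must be divisible by mini_batch_size")
    rng = random.Random(seed)  # one seeded generator reused across epochs
    for _ in range(epochs):
        order = list(range(batch_size))
        if shuffle:
            rng.shuffle(order)
        for start in range(0, batch_size, mini_batch_size):
            yield order[start:start + mini_batch_size]


steps = list(make_index_iterator(batch_size=8, mini_batch_size=2, epochs=3))
print(len(steps))  # 12 == 8 * 3 // 2
```

The divisibility requirement guarantees that every mini-batch has exactly mini_batch_size samples, which is what makes the step count an exact integer division.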
- select(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) → DataProto
Select a subset of the DataProto via batch_keys and meta_info_keys.
- Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to select
meta_info_keys (list, optional): a list of keys indicating the meta info to select
- Returns:
DataProto: the DataProto with the selected batch_keys and meta_info_keys
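A plain-Python sketch of select, with lists standing in for tensors. The behavior when a key list is None (keep every key of that part) and the meaning of deepcopy (return copies rather than values shared with the input) are assumptions inferred from the signature; `Proto` and `select` are hypothetical, and non_tensor_batch is left out for brevity.

```python
import copy
from typing import Any, Dict, List, NamedTuple, Optional


class Proto(NamedTuple):
    """Hypothetical stand-in: batch maps keys to per-sample lists."""
    batch: Dict[str, List[Any]]
    meta_info: Dict[str, Any]


def select(data: Proto, batch_keys: Optional[List[str]] = None,
           meta_info_keys: Optional[List[str]] = None,
           deepcopy: bool = False) -> Proto:
    # Assumption: None keeps every key of that part; unknown keys raise KeyError.
    if batch_keys is None:
        batch = dict(data.batch)
    else:
        batch = {k: data.batch[k] for k in batch_keys}
    if meta_info_keys is None:
        meta = dict(data.meta_info)
    else:
        meta = {k: data.meta_info[k] for k in meta_info_keys}
    out = Proto(batch=batch, meta_info=meta)
    # Without deepcopy the selected values are shared with the input.
    return copy.deepcopy(out) if deepcopy else out


data = Proto(batch={"input_ids": [[1, 2]], "rewards": [0.5]},
             meta_info={"step": 3, "lr": 1e-6})
sub = select(data, batch_keys=["rewards"], meta_info_keys=["step"])
print(sub)  # Proto(batch={'rewards': [0.5]}, meta_info={'step': 3})
```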
- to(device) → DataProto
Move the batch to device.
- Args:
device (torch.device, str): torch device
- Returns:
DataProto: the current DataProto
- union(other: DataProto) → DataProto
Union with another DataProto. The batch and meta_info are unioned separately. Throw an error if:
  - there are conflicting keys in batch and they are not equal;
  - the batch sizes of the two data batches are not the same;
  - there are conflicting keys in meta_info and they are not the same.
- Args:
other (DataProto): another DataProto to union
- Returns:
DataProto: the DataProto after union
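The three error conditions of union can be checked as in the sketch below, a hypothetical plain-Python model (`Proto`, `union`, and the helpers are not verl code) in which lists stand in for tensors, equality is plain `==`, and the batch size is the length of any batch entry.

```python
from typing import Any, Dict, List, NamedTuple


class Proto(NamedTuple):
    """Hypothetical stand-in: batch maps keys to per-sample lists."""
    batch: Dict[str, List[Any]]
    meta_info: Dict[str, Any]


def _merge(left: Dict[str, Any], right: Dict[str, Any], part: str) -> Dict[str, Any]:
    # Shared keys are allowed only when both sides hold equal values.
    merged = dict(left)
    for key, value in right.items():
        if key in merged and merged[key] != value:
            raise ValueError(f"conflicting {part} key: {key!r}")
        merged[key] = value
    return merged


def _size(batch: Dict[str, List[Any]]) -> int:
    return len(next(iter(batch.values()))) if batch else 0


def union(a: Proto, b: Proto) -> Proto:
    if a.batch and b.batch and _size(a.batch) != _size(b.batch):
        raise ValueError("the two batches have different batch sizes")
    return Proto(batch=_merge(a.batch, b.batch, "batch"),
                 meta_info=_merge(a.meta_info, b.meta_info, "meta_info"))


prompts = Proto(batch={"input_ids": [[1], [2]]}, meta_info={"step": 0})
rewards = Proto(batch={"rewards": [0.1, 0.9]}, meta_info={"step": 0})
print(union(prompts, rewards).batch)  # {'input_ids': [[1], [2]], 'rewards': [0.1, 0.9]}
```

A shared key with equal values (here, "step" in meta_info) is accepted, so union is safe to apply to outputs that carry the same metadata.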