llm.datasets.sharded
Utilities for training with sharded datasets.
DistributedShardedDataset
DistributedShardedDataset(
dataset_type: type[Dataset[SampleType]],
shard_params: dict[str, DatasetParams],
*,
rank: int,
world_size: int,
shuffle: bool = False,
seed: int = 0
)
Bases: Dataset[SampleType]
Dataset wrapper for sharded datasets in distributed environments.
This class manages a set of datasets (shards) and restricts each rank to a subset of the global indices across the shards. This is achieved by sorting the shards, counting the samples in each shard to compute the total number of samples, and then chunking those samples by rank.
For example, if there are four ranks and eight shards of equal size, rank
zero will see shards zero and one, rank one will see shards two and three,
and so on. The length of an instance of this class, as seen by any rank,
is (1 / world_size) * sum_of_samples_across_shards.
This class also ensures only one shard is loaded at a time on a rank so the full dataset is never loaded into memory at once.
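As a rough sketch of the arithmetic described above (the shard sizes and variable names here are illustrative, not part of the class API):

```python
# Illustrative only: 8 equal shards of 100 samples across 4 ranks.
shard_sizes = [100] * 8                 # samples per shard, in sorted shard-key order
world_size = 4

total = sum(shard_sizes)                # 800 global samples
samples_per_rank = total // world_size  # 200 samples visible to each rank

for rank in range(world_size):
    start = rank * samples_per_rank
    print(f"rank {rank} sees global indices [{start}, {start + samples_per_rank})")
# rank 0 sees global indices [0, 200)   -> shards 0 and 1
# rank 1 sees global indices [200, 400) -> shards 2 and 3
# ...
```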
Warning
When building a DataLoader from a DistributedShardedDataset, do NOT use
PyTorch's DistributedSampler, because this class already partitions
samples across ranks. If you want to be able to save the state of the
data loader, use a SequentialSampler instead; this module also provides a
ResumableSequentialSampler to enable resuming sampling from the last
sampled index.
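For example, a loader might be built like this (a minimal sketch; it assumes `dataset` is an already-constructed DistributedShardedDataset, and uses PyTorch's SequentialSampler since the exact signature of ResumableSequentialSampler is not shown in this reference):

```python
from torch.utils.data import DataLoader, SequentialSampler

# `dataset` is a DistributedShardedDataset built for this rank; it already
# exposes only this rank's slice of the samples, so no DistributedSampler.
sampler = SequentialSampler(dataset)

loader = DataLoader(dataset, sampler=sampler, batch_size=8)
for batch in loader:
    ...  # each rank iterates a disjoint subset of the global samples
```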
Note
Samples at the end of the last shard will be dropped to ensure each rank sees an equal number of samples.
Todo
- Next shard prefetching
- Sample index shuffling within a shard
- Support shuffling the shard order by epoch
Parameters:
- dataset_type (type[Dataset[SampleType]]) – Dataset type that represents a single shard. This subtype of Dataset must be a map-style dataset. Iterable-style datasets are not supported.
- shard_params (dict[str, DatasetParams]) – Dictionary mapping shard keys to the parameters used to initialize a dataset_type for the shard. The parameter type is a tuple of args and kwargs.
- rank (int) – Rank of this process.
- world_size (int) – Number of ranks sharing the dataset.
- shuffle (bool, default: False) – Shuffle the shard order by the shard keys. The default (False) sorts the shards by shard key.
- seed (int, default: 0) – Seed used for shuffling the shard order.
Source code in llm/datasets/sharded.py
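Putting the parameters together, construction might look like the following sketch. The TokenFileDataset shard class, the file paths, and the max_len keyword are hypothetical stand-ins; only DistributedShardedDataset and its keyword arguments come from this reference.

```python
import torch.distributed as dist
from torch.utils.data import Dataset

from llm.datasets.sharded import DistributedShardedDataset


class TokenFileDataset(Dataset):
    """Hypothetical map-style shard: one tokenized file per shard."""

    def __init__(self, path: str, max_len: int = 512):
        self.path, self.max_len = path, max_len

    def __len__(self) -> int:
        return 100  # placeholder shard size

    def __getitem__(self, index: int):
        return index  # placeholder sample


# Each DatasetParams value is a tuple of (args, kwargs) used to
# initialize one TokenFileDataset shard.
shard_params = {
    f"shard-{i:03d}": ((f"/data/shard-{i:03d}.bin",), {"max_len": 512})
    for i in range(8)
}

# Assumes torch.distributed has already been initialized.
dataset = DistributedShardedDataset(
    TokenFileDataset,
    shard_params,
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    shuffle=True,  # shuffle shard order by key instead of sorting
    seed=42,
)
```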
rank_index_to_global_index
Convert an index local to a rank to a global index.
rank_index_to_shard_index
Convert an index local to a rank to a shard and shard index.
Parameters:
- rank_index (int) – Dataset index local to the rank.

Returns:
- tuple[str, int] – Tuple of the shard key and the index within the shard that rank_index corresponds to.
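Continuing the four-rank, eight-shard example from above, the two conversion methods relate roughly like this (the shard keys, sizes, and exact return values are illustrative; they depend on the shard_params passed at construction):

```python
# 8 shards ("shard-000" .. "shard-007") of 100 samples each, world_size=4,
# shuffle=False, so rank 1 holds global indices [200, 400).
dataset.rank_index_to_global_index(0)   # -> 200
dataset.rank_index_to_shard_index(0)    # -> ("shard-002", 0)
dataset.rank_index_to_shard_index(150)  # -> ("shard-003", 50), i.e. global 350
```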