Custom RoBERTa dataset provider.

This is designed to work with data produced by the RoBERTa encoder preprocessing script in llm.preprocess.roberta.


RoBERTaDataset(
    input_file: Path | str,
    mask_token_id: int,
    mask_token_prob: float,
    vocab_size: int,
)

Bases: Dataset[Sample]

RoBERTa pretraining dataset.

Like a PyTorch Dataset, this dataset is indexable; indexing returns a Sample.

Samples are randomly masked at runtime using the provided parameters. Next sentence prediction is not supported.

>>> from llm.datasets.roberta import RoBERTaDataset
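>>> # The token ID, probability, and vocabulary size below are illustrative;
>>> # use the values that match your tokenizer.
>>> dataset = RoBERTaDataset(
...     '/path/to/shard',
...     mask_token_id=4,
...     mask_token_prob=0.15,
...     vocab_size=50265,
... )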
>>> dataset[5]


  • input_file (Path | str) –

    HDF5 file to load.

  • mask_token_id (int) –

    ID of the mask token in the vocabulary.

  • mask_token_prob (float) –

    Probability of a given token in the sample being masked.

  • vocab_size (int) –

    Size of the vocabulary. Used to replace masked tokens with a random token 10% of the time.

Source code in llm/datasets/
def __init__(
    self,
    input_file: pathlib.Path | str,
    mask_token_id: int,
    mask_token_prob: float,
    vocab_size: int,
) -> None:
    self.input_file = input_file
    self.mask_token_id = mask_token_id
    self.mask_token_prob = mask_token_prob
    self.vocab_size = vocab_size

    self.loaded = False
    with h5py.File(self.input_file, 'r') as f:
        self.samples = len(f['input_ids'])


bert_mask_sequence(
    token_ids: LongTensor,
    special_tokens_mask: BoolTensor,
    mask_token_id: int,
    mask_token_prob: float,
    vocab_size: int,
) -> tuple[LongTensor, LongTensor]

Randomly mask a BERT training sequence.

Source: transformers/data/


  • token_ids (LongTensor) –

    Input sequence token IDs to mask.

  • special_tokens_mask (BoolTensor) –

    Mask of special tokens in the sequence which should never be masked.

  • mask_token_id (int) –

    ID of the mask token in the vocabulary.

  • mask_token_prob (float) –

    Probability of a given token in the sample being masked.

  • vocab_size (int) –

    Size of the vocabulary. Used to replace masked tokens with a random token 10% of the time.
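
The `special_tokens_mask` argument marks positions that must never be masked. One way to build such a mask is to compare the token IDs against the tokenizer's special-token IDs. The sketch below does this with `torch.isin`; the helper name `build_special_tokens_mask` and the special-token IDs (0, 1, 2) are assumptions for illustration and are not part of `llm`.

```python
import torch


def build_special_tokens_mask(
    token_ids: torch.Tensor,
    special_token_ids: list[int],
) -> torch.Tensor:
    """Return True wherever token_ids holds one of special_token_ids."""
    special = torch.tensor(special_token_ids, dtype=token_ids.dtype)
    return torch.isin(token_ids, special)


# Hypothetical IDs: 0 = start of sequence, 2 = end of sequence, 1 = padding.
token_ids = torch.tensor([0, 713, 16, 10, 2, 1, 1])
special_tokens_mask = build_special_tokens_mask(token_ids, [0, 1, 2])
print(special_tokens_mask.tolist())
```

The resulting boolean tensor has the same shape as `token_ids`, which is what the masking function expects.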


  • tuple[LongTensor, LongTensor]

    Masked token_ids and the masked labels.
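
To see how `mask_token_prob` combines with the 80/10/10 replacement split, the per-token probabilities can be worked out directly. This is a small arithmetic sketch, not library code; `masking_outcomes` is a hypothetical helper, and the value 0.15 is only an example.

```python
from fractions import Fraction


def masking_outcomes(mask_token_prob: Fraction) -> dict[str, Fraction]:
    """Per-token probabilities of each outcome for a non-special token."""
    selected = mask_token_prob
    to_mask = selected * Fraction(8, 10)  # replaced with the mask token
    # A second coin flip with p = 1/2 over the remaining 20% picks the
    # tokens that become random tokens.
    to_random = selected * Fraction(2, 10) * Fraction(1, 2)
    kept = selected - to_mask - to_random  # selected but left unchanged
    return {
        'mask': to_mask,
        'random': to_random,
        'unchanged': kept,
        'not_selected': 1 - selected,
    }


outcomes = masking_outcomes(Fraction(15, 100))
for name, prob in outcomes.items():
    print(f'{name}: {float(prob):.3f}')
```

With a selection probability of 0.15, 12% of non-special tokens become the mask token, 1.5% become a random token, and 1.5% are selected for prediction but left unchanged.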

Source code in llm/datasets/
from typing import cast

import torch


def bert_mask_sequence(
    token_ids: torch.LongTensor,
    special_tokens_mask: torch.BoolTensor,
    mask_token_id: int,
    mask_token_prob: float,
    vocab_size: int,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Randomly mask a BERT training sequence.

    Source: `transformers/data/`

    Args:
        token_ids: Input sequence token IDs to mask.
        special_tokens_mask: Mask of special tokens in the sequence which
            should never be masked.
        mask_token_id: ID of the mask token in the vocabulary.
        mask_token_prob: Probability of a given token in the sample being
            masked.
        vocab_size: Size of the vocabulary. Used to replace masked tokens with
            a random token 10% of the time.

    Returns:
        Masked `token_ids` and the masked labels.
    """
    masked_labels = cast(torch.LongTensor, token_ids.clone())

    probability_matrix = torch.full(token_ids.shape, mask_token_prob)
    special_tokens_mask = special_tokens_mask.bool()
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Set non-masked tokens to -100 so loss is only computed on masked tokens
    masked_labels[~masked_indices] = -100

    # 80% of the time replace masked token with [MASK]
    indices_replaced = (
        torch.bernoulli(torch.full(token_ids.shape, 0.8)).bool()
        & masked_indices
    )
    token_ids[indices_replaced] = mask_token_id

    # 10% of the time (half of the remaining 20%) replace masked tokens with
    # a random token
    indices_random = (
        torch.bernoulli(torch.full(token_ids.shape, 0.5)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, token_ids.shape, dtype=torch.long)
    token_ids[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) the masked tokens are unchanged
    return token_ids, masked_labels
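
The `-100` label value matches the default `ignore_index` of PyTorch's cross-entropy loss, so positions that were not selected for masking contribute nothing to the loss. The following is a minimal sketch with made-up logits and labels; the toy vocabulary size, shapes, and label values are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10  # hypothetical toy vocabulary
logits = torch.randn(6, vocab_size)  # one row of scores per token position

# Labels as returned by a masking step: -100 everywhere except the two
# positions that were selected for prediction.
masked_labels = torch.tensor([-100, 3, -100, -100, 7, -100])

# cross_entropy skips targets equal to ignore_index (default: -100).
loss = F.cross_entropy(logits, masked_labels)

# The same loss computed only over the selected positions.
selected = masked_labels != -100
loss_selected = F.cross_entropy(logits[selected], masked_labels[selected])

print(torch.allclose(loss, loss_selected))
```

Both computations average over the two selected positions only, which is why the two losses agree.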