llm.trainers.bert.utils
BERT pretraining utilities.
TrainingConfig
dataclass
TrainingConfig(
PHASE: int,
BERT_CONFIG: dict[str, Any],
OPTIMIZER: Literal["adam", "lamb"],
CHECKPOINT_DIR: str,
TENSORBOARD_DIR: str,
DATASET_CONFIG: Union[
NvidiaBertDatasetConfig, RobertaDatasetConfig
],
GLOBAL_BATCH_SIZE: int,
BATCH_SIZE: int,
STEPS: int,
CHECKPOINT_STEPS: int,
LR: float,
WARMUP_STEPS: int,
ACCUMULATION_STEPS: int,
CLIP_GRAD_NORM: Optional[float] = None,
DTYPE: Optional[torch.dtype] = None,
GRADIENT_CHECKPOINTING: bool = False,
LOG_FILE: Optional[str] = None,
SEED: int = 42,
)
Training configuration.
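A hedged sketch of how such a config might be constructed. The real `TrainingConfig` lives in `llm.trainers.bert.utils` and has more fields; the trimmed stand-in dataclass below is an illustration only, using a hypothetical world size of 16 to show the usual relationship between the global and per-device batch sizes.

```python
from dataclasses import dataclass
from typing import Any

# Trimmed stand-in mirroring a few TrainingConfig fields for illustration;
# the real class has many more required options (see the signature above).
@dataclass
class TrainingConfigSketch:
    PHASE: int
    BERT_CONFIG: dict[str, Any]
    OPTIMIZER: str  # "adam" or "lamb" in the real config
    GLOBAL_BATCH_SIZE: int
    BATCH_SIZE: int
    ACCUMULATION_STEPS: int
    LR: float
    SEED: int = 42

config = TrainingConfigSketch(
    PHASE=1,
    BERT_CONFIG={"hidden_size": 768},
    OPTIMIZER="lamb",
    GLOBAL_BATCH_SIZE=8192,
    BATCH_SIZE=64,
    ACCUMULATION_STEPS=8,
    LR=6e-3,
)

# Global batch = per-device batch x accumulation steps x world size;
# with an assumed world size of 16: 64 * 8 * 16 == 8192.
assert config.BATCH_SIZE * config.ACCUMULATION_STEPS * 16 == config.GLOBAL_BATCH_SIZE
```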
parse_config()
Parses a config, ensuring all required options are present.
Source code in llm/trainers/bert/utils.py
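A minimal sketch of what this validation step might look like, assuming the config arrives as a plain dict before the dataclass is built. The required-key set is taken from the fields without defaults in the `TrainingConfig` signature above; `parse_config_sketch` is a hypothetical stand-in, not the library's implementation.

```python
# Fields of TrainingConfig that have no default value and so must be present.
REQUIRED = {
    "PHASE", "BERT_CONFIG", "OPTIMIZER", "CHECKPOINT_DIR",
    "TENSORBOARD_DIR", "DATASET_CONFIG", "GLOBAL_BATCH_SIZE",
    "BATCH_SIZE", "STEPS", "CHECKPOINT_STEPS", "LR",
    "WARMUP_STEPS", "ACCUMULATION_STEPS",
}

def parse_config_sketch(raw: dict) -> dict:
    """Raise a descriptive error if any required option is missing."""
    missing = REQUIRED - raw.keys()
    if missing:
        raise ValueError(f"Config is missing required options: {sorted(missing)}")
    return raw
```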
checkpoint()
checkpoint(
config: TrainingConfig,
global_step: int,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
sampler_index: int = 0,
) -> None
Write a training checkpoint.
load_state()
load_state(
config: TrainingConfig,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer | None = None,
scheduler: (
torch.optim.lr_scheduler._LRScheduler | None
) = None,
) -> tuple[int, int, int]
Load the latest checkpoint if one exists.
Returns:
A tuple of the restored global step, epoch, and sampler index.
get_optimizer_grouped_parameters()
Get the parameters of the BERT model grouped for the optimizer.
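A sketch of the grouping this helper most likely performs, following the standard BERT pretraining convention: apply weight decay to most parameters but exempt biases and LayerNorm weights. The substring-based name filter and the default decay of 0.01 are assumptions; the real function operates on a `torch.nn.Module`'s `named_parameters()`, for which plain `(name, value)` pairs stand in here.

```python
def grouped_parameters_sketch(named_params, weight_decay: float = 0.01):
    """Split (name, param) pairs into decay and no-decay optimizer groups."""
    named_params = list(named_params)  # allow a one-shot iterator
    no_decay = ("bias", "LayerNorm.weight")
    decay = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
    nodecay = [p for n, p in named_params if any(nd in n for nd in no_decay)]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": nodecay, "weight_decay": 0.0},
    ]
```

The returned list has the shape `torch.optim` optimizers accept as their first argument, so it could be passed directly to an Adam or LAMB constructor.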