llm.engine.amp

Utilities for easy automatic mixed precision training.

AMPCriterion

AMPCriterion(
    criterion: torch.nn.Module, autocast: torch.autocast
)

Bases: Module

Wrap a loss function for AMP training.

Parameters:

  • criterion (Module) –

    Loss function to wrap.

  • autocast (autocast) –

    Autocast context manager to compute the loss inside.

Source code in llm/engine/amp.py
def __init__(
    self,
    criterion: torch.nn.Module,
    autocast: torch.autocast,
) -> None:
    super().__init__()
    self._criterion = criterion
    self._autocast = autocast

forward()

forward(*args: Any, **kwargs: Any) -> Any

Compute the loss inside the autocast.

Source code in llm/engine/amp.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Compute the loss inside the autocast."""
    with self._autocast:
        return self._criterion(*args, **kwargs)
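
As a minimal usage sketch (the loss function, tensors, and the CPU bfloat16 autocast below are illustrative assumptions, not part of the library):

import torch

from llm.engine.amp import AMPCriterion

# bfloat16 on CPU keeps the sketch runnable without a GPU.
autocast = torch.autocast('cpu', dtype=torch.bfloat16)
criterion = AMPCriterion(torch.nn.CrossEntropyLoss(), autocast)

logits = torch.randn(8, 10)           # stand-in model outputs
targets = torch.randint(0, 10, (8,))  # stand-in class labels
loss = criterion(logits, targets)     # computed inside the autocast context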

AMPModel

AMPModel(model: torch.nn.Module, autocast: torch.autocast)

Bases: Module

Wrap a model for AMP training.

Parameters:

  • model (Module) –

    Model to wrap.

  • autocast (autocast) –

    Autocast context manager to run the forward pass inside.

Source code in llm/engine/amp.py
def __init__(
    self,
    model: torch.nn.Module,
    autocast: torch.autocast,
) -> None:
    super().__init__()
    self._model = model
    self._autocast = autocast

forward()

forward(*args: Any, **kwargs: Any) -> Any

Perform a forward pass inside the autocast.

Source code in llm/engine/amp.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Perform a forward pass inside the autocast."""
    with self._autocast:
        return self._model(*args, **kwargs)
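
A corresponding sketch for wrapping a model; the linear layer and input here are assumptions made for illustration:

import torch

from llm.engine.amp import AMPModel

autocast = torch.autocast('cpu', dtype=torch.bfloat16)
model = AMPModel(torch.nn.Linear(16, 4), autocast)

x = torch.randn(2, 16)
y = model(x)  # the forward pass runs under autocast, so the matmul is bfloat16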

AMPOptimizer

AMPOptimizer(
    model: torch.nn.Module,
    optimizer: Optimizer,
    scaler: GradScaler,
    max_norm: float | None = None,
)

Bases: BaseOptimizer

Wrap an optimizer for AMP training.

Parameters:

  • model (Module) –

    Model being optimized.

  • optimizer (Optimizer) –

    Optimizer to wrap.

  • scaler (GradScaler) –

    Gradient scaler.

  • max_norm (float | None, default: None ) –

    Optionally clip gradient norm.

Source code in llm/engine/amp.py
def __init__(
    self,
    model: torch.nn.Module,
    optimizer: Optimizer,
    scaler: GradScaler,
    max_norm: float | None = None,
) -> None:
    super().__init__(optimizer)
    self._model = model
    self._scaler = scaler
    self._max_norm = max_norm

zero_grad()

zero_grad(*args: Any, **kwargs: Any) -> None

Zero the gradients of optimized parameters.

Source code in llm/engine/base.py
def zero_grad(self, *args: Any, **kwargs: Any) -> None:
    """Zero the gradients of optimized parameters."""
    self._optimizer.zero_grad(*args, **kwargs)

state_dict()

state_dict() -> dict[str, Any]

Dictionary containing references to the whole state of the optimizer.

Includes the state of the grad_scaler.

Source code in llm/engine/amp.py
def state_dict(self) -> dict[str, Any]:
    """Dictionary containing references to the whole state of the module.

    Includes the state of the `grad_scaler`.
    """
    state_dict = self._optimizer.state_dict()
    assert 'grad_scaler' not in state_dict
    state_dict['grad_scaler'] = self._scaler.state_dict()
    return state_dict

load_state_dict()

load_state_dict(state_dict: dict[str, Any]) -> None

Copy the state into this module.

Source code in llm/engine/amp.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Copy the state into this module."""
    scaler_state_dict = state_dict.pop('grad_scaler', None)
    if scaler_state_dict is not None:  # pragma: no branch
        self._scaler.load_state_dict(scaler_state_dict)
    self._optimizer.load_state_dict(state_dict)
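
Because the GradScaler state rides along inside the optimizer state dict, checkpointing can be sketched as below; the toy model, optimizer, and file name are placeholder assumptions, and the scaler is disabled so the sketch also runs on CPU:

import torch
from torch.cuda.amp import GradScaler

from llm.engine.amp import AMPOptimizer

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
amp_optimizer = AMPOptimizer(model, optimizer, GradScaler(enabled=False))

state = amp_optimizer.state_dict()
assert 'grad_scaler' in state      # scaler state is embedded alongside the optimizer state

torch.save(state, 'optimizer.pt')  # checkpoint
amp_optimizer.load_state_dict(torch.load('optimizer.pt'))  # restores optimizer and scaler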

backward()

backward(loss: torch.Tensor) -> None

Perform a backward pass and correctly scale the loss.

Source code in llm/engine/amp.py
def backward(self, loss: torch.Tensor) -> None:
    """Perform a backward pass and correctly scale the loss."""
    self._scaler.scale(loss).backward()

step()

step(closure: Callable[[], float] | None = None) -> None

Perform an optimization step using the gradient scaler.

Source code in llm/engine/amp.py
def step(self, closure: Callable[[], float] | None = None) -> None:
    """Perform an optimization using the gradient scaler."""
    if self._max_norm is not None:
        self._scaler.unscale_(self._optimizer)
        torch.nn.utils.clip_grad_norm_(
            self._model.parameters(),
            self._max_norm,
        )
    self._scaler.step(self._optimizer)
    self._scaler.update()
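
For reference, this is roughly the manual torch.cuda.amp recipe that backward() and step() encapsulate; the model, data, and max_norm value are placeholder assumptions, and a CUDA device is required for a real (enabled) GradScaler:

import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()
max_norm = 1.0

x = torch.randn(4, 16, device='cuda')
y = torch.randn(4, 1, device='cuda')

with torch.autocast('cuda', dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # what backward() does
scaler.unscale_(optimizer)     # unscale before clipping...
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)         # ...then step and update, as in step()
scaler.update()
optimizer.zero_grad()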

initialize()

initialize(
    model: torch.nn.Module,
    optimizer: Optimizer,
    criterion: torch.nn.Module,
    dtype: torch.dtype = torch.float16,
    max_norm: float | None = None,
    **kwargs: Any
) -> tuple[AMPModel, AMPOptimizer, AMPCriterion]

Initialize AMP training.

Parameters:

  • model (Module) –

    Model being optimized.

  • optimizer (Optimizer) –

    Optimizer to wrap.

  • criterion (Module) –

    Loss function to wrap.

  • dtype (dtype, default: float16 ) –

    Data type to perform mixed precision in. Typically torch.float16 or torch.bfloat16.

  • max_norm (float | None, default: None ) –

    Optionally clip gradient norm.

  • kwargs (Any, default: {} ) –

    Additional keyword arguments to pass to the GradScaler.

Returns:

  • tuple[AMPModel, AMPOptimizer, AMPCriterion] –

    A tuple of the wrapped model, optimizer, and loss.

Source code in llm/engine/amp.py
def initialize(
    model: torch.nn.Module,
    optimizer: Optimizer,
    criterion: torch.nn.Module,
    dtype: torch.dtype = torch.float16,
    max_norm: float | None = None,
    **kwargs: Any,
) -> tuple[AMPModel, AMPOptimizer, AMPCriterion]:
    """Initialize AMP training.

    Args:
        model: Model being optimized.
        optimizer: Optimizer to wrap.
        criterion: Loss function to wrap.
        dtype: Data type to perform mixed precision in. Typically
            `torch.float16` or `torch.bfloat16`.
        max_norm: Optionally clip gradient norm.
        kwargs: Additional keyword arguments to pass to the
            [`GradScaler`][torch.cuda.amp.GradScaler].

    Returns:
        A tuple of the wrapped model, optimizer, and loss.
    """
    device = 'cuda' if next(model.parameters()).is_cuda else 'cpu'
    autocast = torch.autocast(device, dtype=dtype)

    # GradScaler only works on CUDA tensors so we disable on CPU
    scaler = GradScaler(**kwargs, enabled=device == 'cuda')

    return (
        AMPModel(model, autocast),
        AMPOptimizer(model, optimizer, scaler, max_norm),
        AMPCriterion(criterion, autocast),
    )
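
An end-to-end sketch of initialize() in a toy training loop; the model, data, and hyperparameters are assumptions, and bfloat16 is used so the example also works on CPU (where the GradScaler is disabled automatically):

import torch

from llm.engine.amp import initialize

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

model, optimizer, criterion = initialize(
    model,
    optimizer,
    criterion,
    dtype=torch.bfloat16,  # torch.float16 is the usual choice on CUDA
    max_norm=1.0,
)

for _ in range(10):
    x = torch.randn(8, 16)
    targets = torch.randint(0, 4, (8,))
    optimizer.zero_grad()
    loss = criterion(model(x), targets)  # forward and loss both run under autocast
    optimizer.backward(loss)             # scales the loss before the backward pass
    optimizer.step()                     # unscales, clips to max_norm, steps, updates the scaler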