Skip to content

llm.engine.accumulation

Utilities for easy gradient accumulation training.

GradientAccumulationOptimizer

GradientAccumulationOptimizer(
    optimizer: BaseOptimizer,
    model: torch.nn.Module,
    accumulation_steps: int,
)

Bases: BaseOptimizer

Optimizer wrapper for enabling gradient accumulation.

This wrapper will skip calls to BaseOptimizer.step() until accumulation_steps forward/backward passes have been performed.

Parameters:

  • optimizer (BaseOptimizer) –

    Optimizer to wrap.

  • model (Module) –

    Model being optimized.

  • accumulation_steps (int) –

    Number of iterations between optimization steps.

Source code in llm/engine/accumulation.py
def __init__(
    self,
    optimizer: BaseOptimizer,
    model: torch.nn.Module,
    accumulation_steps: int,
) -> None:
    """Wrap `optimizer` so it only steps every `accumulation_steps` passes.

    Args:
        optimizer: Optimizer to wrap.
        model: Model being optimized. If it is a `DistributedDataParallel`
            instance, `backward()` skips gradient synchronization on
            passes that do not complete an accumulation window.
        accumulation_steps: Number of forward/backward passes between
            optimization steps. Must be at least 1.

    Raises:
        ValueError: If `accumulation_steps` is less than 1.
    """
    # Validate before touching the wrapped optimizer. A value of 0 would
    # divide the loss by zero in backward(), and a negative value would flip
    # the sign of every gradient. In both cases the step counter (which
    # counts up from 1) would never reach the target, so step() never fires.
    if accumulation_steps < 1:
        raise ValueError(
            f'accumulation_steps must be >= 1, got {accumulation_steps}.',
        )
    super().__init__(optimizer)
    self._accumulation_steps = accumulation_steps
    # Number of backward passes performed since the last optimization step.
    self._accumulation_step = 0
    self._model = model

accumulation_boundary()

accumulation_boundary() -> bool

Return whether the current step is an accumulation boundary.

I.e., the last call to step() resulted in an optimization step and no accumulation for the next step has started.

Source code in llm/engine/accumulation.py
def accumulation_boundary(self) -> bool:
    """Check whether the wrapper currently sits on an accumulation boundary.

    A boundary means the most recent call to
    [`step()`][llm.engine.accumulation.GradientAccumulationOptimizer.step]
    performed a real optimization step (or no pass has run yet) and no
    backward pass for the next accumulation window has been taken.

    Returns:
        `True` when no backward passes are pending, `False` otherwise.
    """
    pending_passes = self._accumulation_step
    return pending_passes == 0

backward()

backward(loss: torch.Tensor) -> None

Perform a backward pass.

Note

If model is a DistributedDataParallel instance, backward passes will be performed with no_sync() during gradient accumulation steps.

Parameters:

  • loss (Tensor) –

    Loss to compute gradients with respect to.

Source code in llm/engine/accumulation.py
def backward(self, loss: torch.Tensor) -> None:
    """Run a backward pass on the loss, scaled for gradient accumulation.

    The loss is divided by `accumulation_steps` before the backward pass.
    Gradients summed over one full accumulation window therefore equal the
    gradient of the window's mean loss.

    Note:
        If `model` is a
        [`DistributedDataParallel`][torch.nn.parallel.DistributedDataParallel]
        instance, backward passes will be performed with
        [`no_sync()`][torch.nn.parallel.DistributedDataParallel.no_sync]
        during gradient accumulation steps.

    Args:
        loss: Loss to compute gradients with respect to.
    """
    self._accumulation_step += 1

    # Only the pass that completes a window needs DDP gradient
    # synchronization; earlier passes accumulate locally under no_sync().
    mid_window = self._accumulation_step < self._accumulation_steps
    if mid_window and isinstance(self._model, DistributedDataParallel):
        sync_guard = self._model.no_sync()
    else:
        sync_guard = contextlib.nullcontext()

    with sync_guard:
        self._optimizer.backward(loss / self._accumulation_steps)

step()

step(*args: Any, **kwargs: Any) -> None

Perform an optimization step.

This method is a no-op unless accumulation_steps backward passes have been performed since the last optimization step.

Source code in llm/engine/accumulation.py
def step(self, *args: Any, **kwargs: Any) -> None:
    """Perform an optimization step if the accumulation window is complete.

    This method is a no-op until `accumulation_steps` backward passes have
    been performed since the last optimization step. Once the window is
    complete, the pass counter is reset and the wrapped optimizer's
    `step()` is called.

    Args:
        args: Positional arguments forwarded to the wrapped optimizer's
            `step()`.
        kwargs: Keyword arguments forwarded to the wrapped optimizer's
            `step()`.
    """
    # Use >= rather than ==. If backward() runs more than accumulation_steps
    # times before step() (e.g., several losses per iteration), an exact
    # equality check would never match again. The optimizer would then
    # silently stop stepping, and zero_grad() (which only acts when the
    # counter is 0) would stop clearing gradients.
    if self._accumulation_step >= self._accumulation_steps:
        self._accumulation_step = 0
        self._optimizer.step(*args, **kwargs)

zero_grad()

zero_grad(*args: Any, **kwargs: Any) -> None

Zero the gradients of the wrapped optimizer.

Source code in llm/engine/accumulation.py
def zero_grad(self, *args: Any, **kwargs: Any) -> None:
    """Zero the wrapped optimizer's gradients, but only at window boundaries.

    Calls made while an accumulation window is in progress are ignored. This
    keeps the gradients from earlier backward passes in the same window.

    Args:
        args: Positional arguments forwarded to the wrapped optimizer's
            `zero_grad()`.
        kwargs: Keyword arguments forwarded to the wrapped optimizer's
            `zero_grad()`.
    """
    if self._accumulation_step != 0:
        # Mid-window: keep the accumulated gradients.
        return
    self._optimizer.zero_grad(*args, **kwargs)

GradientAccumulationLRScheduler

GradientAccumulationLRScheduler(
    scheduler: _LRScheduler, accumulation_steps: int
)

Bases: _LRScheduler

LR scheduler wrapper that accounts for gradient accumulation.

This wrapper allows you to call scheduler.step() after every forward/backward pass and will correctly skip the call if it happens during a gradient accumulation period.

Parameters:

  • scheduler (_LRScheduler) –

    LR scheduler to wrap.

  • accumulation_steps (int) –

    Number of iterations between optimization steps.

Source code in llm/engine/accumulation.py
def __init__(
    self,
    scheduler: _LRScheduler,
    accumulation_steps: int,
) -> None:
    """Wrap `scheduler` so it advances once per accumulation window.

    Args:
        scheduler: LR scheduler to wrap.
        accumulation_steps: Number of `step()` calls (one per
            forward/backward pass) between updates of the wrapped
            scheduler. Must be at least 1.

    Raises:
        ValueError: If `accumulation_steps` is less than 1.
    """
    # step() increments the counter to 1 before comparing, so a value below
    # 1 would never match and the wrapped scheduler would never advance.
    if accumulation_steps < 1:
        raise ValueError(
            f'accumulation_steps must be >= 1, got {accumulation_steps}.',
        )
    # NOTE(review): _LRScheduler.__init__ is intentionally not called here,
    # so attributes it would normally set are absent on this wrapper.
    # Inherited helpers that read them (e.g. get_last_lr()) may fail.
    # TODO confirm callers only use step().
    self._accumulation_steps = accumulation_steps
    # Number of step() calls since the wrapped scheduler last advanced.
    self._accumulation_step = 0
    self._scheduler = scheduler

step()

step(epoch: int | None = None) -> None

Update the learning rate.

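Each call counts one forward/backward pass; this method is a no-op unless it completes a window of accumulation_steps calls, in which case the wrapped scheduler is stepped.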

Source code in llm/engine/accumulation.py
def step(self, epoch: int | None = None) -> None:
    """Advance the wrapped scheduler once per accumulation window.

    Every call counts one forward/backward pass. The wrapped scheduler's
    `step()` runs only on the call that completes a window of
    `accumulation_steps` passes, after which the count restarts from zero.

    Args:
        epoch: Forwarded unchanged to the wrapped scheduler's `step()`.
    """
    self._accumulation_step += 1
    if self._accumulation_step != self._accumulation_steps:
        # Still inside the accumulation window.
        return
    self._accumulation_step = 0
    self._scheduler.step(epoch)

initialize()

initialize(
    model: torch.nn.Module,
    optimizer: BaseOptimizer,
    scheduler: _LRScheduler,
    accumulation_steps: int = 1,
) -> tuple[
    GradientAccumulationOptimizer,
    GradientAccumulationLRScheduler,
]

Initialize gradient accumulation training.

Parameters:

  • model (Module) –

    Model being optimized.

  • optimizer (BaseOptimizer) –

    Optimizer to wrap.

  • scheduler (_LRScheduler) –

    LR scheduler to wrap.

  • accumulation_steps (int, default: 1) –

    Number of iterations between optimization steps.

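Returns:

  • tuple[GradientAccumulationOptimizer, GradientAccumulationLRScheduler] –

    The wrapped optimizer and LR scheduler.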

Source code in llm/engine/accumulation.py
def initialize(
    model: torch.nn.Module,
    optimizer: BaseOptimizer,
    scheduler: _LRScheduler,
    accumulation_steps: int = 1,
) -> tuple[GradientAccumulationOptimizer, GradientAccumulationLRScheduler]:
    """Set up gradient accumulation by wrapping an optimizer and a scheduler.

    Both wrappers receive the same `accumulation_steps`. If each is stepped
    once per iteration, they advance on the same iterations.

    Args:
        model: Model being optimized.
        optimizer: Optimizer to wrap.
        scheduler: LR scheduler to wrap.
        accumulation_steps: Number of iterations between optimization steps.

    Returns:
        The wrapped optimizer and LR scheduler, in that order.
    """
    wrapped_optimizer = GradientAccumulationOptimizer(
        optimizer,
        model,
        accumulation_steps,
    )
    wrapped_scheduler = GradientAccumulationLRScheduler(
        scheduler,
        accumulation_steps,
    )
    return wrapped_optimizer, wrapped_scheduler