
lars

References
  • https://arxiv.org/pdf/1708.03888.pdf
  • https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py

LARS

Bases: Optimizer

Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks (https://arxiv.org/pdf/1708.03888.pdf).

Args:
    params (iterable): iterable of parameters to optimize or dicts defining parameter groups
    lr (float): learning rate
    momentum (float, optional): momentum factor (default: 0)
    weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    dampening (float, optional): dampening for momentum (default: 0)
    nesterov (bool, optional): enables Nesterov momentum (default: False)
    trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
    eps (float, optional): eps for division denominator (default: 1e-8)

Example

model = torch.nn.Linear(10, 1)
input = torch.Tensor(10)
target = torch.Tensor([1.])
loss_fn = lambda input, target: (input - target) ** 2

optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

.. note:: The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion.

.. math::
    \begin{aligned}
        g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
        v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
        p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
parameters, gradient, velocity, momentum, and weight decay respectively.
The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
The Nesterov version is analogously modified.
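
Concretely, the layer-wise factor lars_lr from Eq. 6 is a single scalar per parameter tensor, computed from the norms of the weights and gradients. A minimal sketch of that computation, assuming the hyperparameter names of this class (trust_coefficient, weight_decay, eps); the helper name lars_trust_ratio is illustrative only:

import torch

def lars_trust_ratio(param, grad, weight_decay=0.0, trust_coefficient=0.001, eps=1e-8):
    # Scalar LARS factor for one parameter tensor (a sketch of Eq. 6 in the paper).
    p_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    if p_norm == 0 or g_norm == 0:
        return 1.0  # degenerate case: leave the gradient unscaled, as step() does
    return float(trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps))

# Example: gradients that are small relative to the weights give a factor well above trust_coefficient.
w = torch.randn(128, 64)
g = 1e-3 * torch.randn(128, 64)
print(lars_trust_ratio(w, g, weight_decay=1e-4))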

.. warning:: Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
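
Because scaling is skipped whenever a group's weight decay is 0, excluding biases and normalization parameters from LARS amounts to placing them in their own parameter group. A hedged sketch of that setup (the model, the ndim check, and the weight-decay value are illustrative assumptions; LARS is assumed importable from fmcib.optimizers.lars, based on the source location below):

import torch
from fmcib.optimizers.lars import LARS  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10))

# Biases and norm parameters (ndim <= 1) go into a weight_decay=0 group, which
# also excludes them from layer-wise LR scaling, as described in the warning above.
decay, no_decay = [], []
for param in model.parameters():
    (no_decay if param.ndim <= 1 else decay).append(param)

optimizer = LARS(
    [
        {"params": decay, "weight_decay": 1e-6},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
    momentum=0.9,
)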

Source code in fmcib/optimizers/lars.py
class LARS(Optimizer):
    """Extends SGD in PyTorch with LARS scaling from the paper
    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
        eps (float, optional): eps for division denominator (default: 1e-8)

    Example:
        >>> model = torch.nn.Linear(10, 1)
        >>> input = torch.Tensor(10)
        >>> target = torch.Tensor([1.])
        >>> loss_fn = lambda input, target: (input - target) ** 2
        >>> #
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. note::
        The application of momentum in the SGD part is modified according to
        the PyTorch standards. LARS scaling fits into the equation in the
        following fashion.

        .. math::
            \begin{aligned}
                g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
        parameters, gradient, velocity, momentum, and weight decay respectively.
        The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
        The Nesterov version is analogously modified.

    .. warning::
        Parameters with weight decay set to 0 will automatically be excluded from
        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
        and BYOL.
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        trust_coefficient=0.001,
        eps=1e-8,
    ):
        """
        Initialize an optimizer with the given parameters.

        Args:
            params (iterable): Iterable of parameters to optimize.
            lr (float, optional): Learning rate. Default is required.
            momentum (float, optional): Momentum factor. Default is 0.
            dampening (float, optional): Dampening for momentum. Default is 0.
            weight_decay (float, optional): Weight decay factor. Default is 0.
            nesterov (bool, optional): Use Nesterov momentum. Default is False.
            trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
            eps (float, optional): Small value for numerical stability. Default is 1e-08.

        Raises:
            ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
            ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
        """
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coefficient=trust_coefficient,
            eps=eps,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        super().__init__(params, defaults)

    def __setstate__(self, state):
        """
        Set the state of the optimizer.

        Args:
            state (dict): A dictionary containing the state of the optimizer.

        Returns:
            None

        Note:
            This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
        """
        super().__setstate__(state)

        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Parameters:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # exclude scaling for params with 0 weight decay
        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad
                p_norm = torch.norm(p.data)
                g_norm = torch.norm(p.grad.data)

                # lars scaling + weight decay part
                if weight_decay != 0:
                    if p_norm != 0 and g_norm != 0:
                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
                        lars_lr *= group["trust_coefficient"]

                        d_p = d_p.add(p, alpha=weight_decay)
                        d_p *= lars_lr

                # sgd part
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group["lr"])

        return loss

__init__(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)

Initialize an optimizer with the given parameters.

Parameters:

Name                Type       Description                                          Default
params              iterable   Iterable of parameters to optimize.                  required
lr                  float      Learning rate.                                       required
momentum            float      Momentum factor.                                     0
dampening           float      Dampening for momentum.                              0
weight_decay        float      Weight decay factor.                                 0
nesterov            bool       Use Nesterov momentum.                               False
trust_coefficient   float      Trust coefficient.                                   0.001
eps                 float      Small value for numerical stability.                 1e-08

Raises:

Type         Description
ValueError   If an invalid value is provided for lr, momentum, or weight_decay.
ValueError   If Nesterov momentum is enabled without a momentum and zero dampening.

Source code in fmcib/optimizers/lars.py
def __init__(
    self,
    params,
    lr=required,
    momentum=0,
    dampening=0,
    weight_decay=0,
    nesterov=False,
    trust_coefficient=0.001,
    eps=1e-8,
):
    """
    Initialize an optimizer with the given parameters.

    Args:
        params (iterable): Iterable of parameters to optimize.
        lr (float, optional): Learning rate. Default is required.
        momentum (float, optional): Momentum factor. Default is 0.
        dampening (float, optional): Dampening for momentum. Default is 0.
        weight_decay (float, optional): Weight decay factor. Default is 0.
        nesterov (bool, optional): Use Nesterov momentum. Default is False.
        trust_coefficient (float, optional): Trust coefficient. Default is 0.001.
        eps (float, optional): Small value for numerical stability. Default is 1e-08.

    Raises:
        ValueError: If an invalid value is provided for lr, momentum, or weight_decay.
        ValueError: If nesterov momentum is enabled without providing a momentum and zero dampening.
    """
    if lr is not required and lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if momentum < 0.0:
        raise ValueError(f"Invalid momentum value: {momentum}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    defaults = dict(
        lr=lr,
        momentum=momentum,
        dampening=dampening,
        weight_decay=weight_decay,
        nesterov=nesterov,
        trust_coefficient=trust_coefficient,
        eps=eps,
    )
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")

    super().__init__(params, defaults)
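
As listed in the Raises section, invalid hyperparameters are rejected at construction time. A small sketch, reusing model from the class-level example:

try:
    LARS(model.parameters(), lr=-0.1)
except ValueError as err:
    print(err)  # "Invalid learning rate: -0.1"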

__setstate__(state)

Set the state of the optimizer.

Parameters:

Name    Type   Description                                            Default
state   dict   A dictionary containing the state of the optimizer.    required

Returns:

Type   Description
None

Note

This method is an override of the __setstate__ method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the nesterov parameter in each group of the optimizer to False if it is not already present.

Source code in fmcib/optimizers/lars.py
def __setstate__(self, state):
    """
    Set the state of the optimizer.

    Args:
        state (dict): A dictionary containing the state of the optimizer.

    Returns:
        None

    Note:
        This method is an override of the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each group of the optimizer to `False` if it is not already present.
    """
    super().__setstate__(state)

    for group in self.param_groups:
        group.setdefault("nesterov", False)

step(closure=None)

Performs a single optimization step.

Parameters:

Name      Type       Description                                                    Default
closure   callable   A closure that reevaluates the model and returns the loss.     None
Source code in fmcib/optimizers/lars.py
@torch.no_grad()
def step(self, closure=None):
    """
    Performs a single optimization step.

    Parameters:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # exclude scaling for params with 0 weight decay
    for group in self.param_groups:
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        dampening = group["dampening"]
        nesterov = group["nesterov"]

        for p in group["params"]:
            if p.grad is None:
                continue

            d_p = p.grad
            p_norm = torch.norm(p.data)
            g_norm = torch.norm(p.grad.data)

            # lars scaling + weight decay part
            if weight_decay != 0:
                if p_norm != 0 and g_norm != 0:
                    lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
                    lars_lr *= group["trust_coefficient"]

                    d_p = d_p.add(p, alpha=weight_decay)
                    d_p *= lars_lr

            # sgd part
            if momentum != 0:
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                else:
                    buf = param_state["momentum_buffer"]
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            p.add_(d_p, alpha=-group["lr"])

    return loss
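
Because step() wraps the closure in torch.enable_grad(), the closure can call backward() even though step() itself is decorated with @torch.no_grad(). A minimal sketch reusing model, input, target, loss_fn, and optimizer from the class-level example:

def closure():
    # Re-evaluate the model, repopulate gradients, and return the loss.
    optimizer.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()
    return loss

loss = optimizer.step(closure)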