lars
References
- https://arxiv.org/pdf/1708.03888.pdf
- https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py
LARS
Bases: Optimizer
Extends SGD in PyTorch with LARS scaling from the paper
`Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001)
eps (float, optional): small constant added to the division denominator for numerical stability (default: 1e-8)
Example
    model = torch.nn.Linear(10, 1)
    input = torch.Tensor(10)
    target = torch.Tensor([1.])
    loss_fn = lambda input, target: (input - target) ** 2
    optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
    optimizer.zero_grad()
    loss_fn(model(input), target).backward()
    optimizer.step()
.. note:: The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion.
.. math::
    \begin{aligned}
        g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
        v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
        p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
    \end{aligned}

where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
parameters, gradient, velocity, momentum, and weight decay respectively.
The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
The Nesterov version is analogously modified.
.. warning:: Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL.
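For intuition, the scaling above can be written as a short per-parameter sketch. This is an illustration only, not the fmcib source; the helper name `lars_scaled_grad` is hypothetical, and the trust ratio follows Eq. 6 of the paper.

```python
import torch

def lars_scaled_grad(p, g, weight_decay=0.0, trust_coefficient=0.001, eps=1e-8):
    """Illustrative LARS scaling for one parameter tensor (not the fmcib source)."""
    p_norm = torch.norm(p)
    g_norm = torch.norm(g)
    # Eq. 6: layer-wise trust ratio; skip scaling when either norm is zero.
    if p_norm != 0 and g_norm != 0:
        lars_lr = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)
    else:
        lars_lr = 1.0
    # Fold weight decay and the local LR into the gradient, as in the note above;
    # the momentum and parameter updates then proceed as in plain SGD.
    return lars_lr * (g + weight_decay * p)
```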
Source code in fmcib/optimizers/lars.py
__init__(params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-08)
Initialize an optimizer with the given parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | iterable | Iterable of parameters to optimize. | required |
| lr | float | Learning rate. | required |
| momentum | float | Momentum factor. | 0 |
| dampening | float | Dampening for momentum. | 0 |
| weight_decay | float | Weight decay factor. | 0 |
| nesterov | bool | Use Nesterov momentum. | False |
| trust_coefficient | float | Trust coefficient. | 0.001 |
| eps | float | Small value for numerical stability. | 1e-08 |
Raises:
| Type | Description |
|---|---|
| ValueError | If an invalid value is provided for lr, momentum, or weight_decay. |
| ValueError | If Nesterov momentum is enabled without a positive momentum or with nonzero dampening. |
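A hedged construction example with two parameter groups, tying in the warning above about `weight_decay=0` groups being excluded from layer-wise LR scaling. The import path `fmcib.optimizers.lars` is assumed from the source file noted below.

```python
import torch
from fmcib.optimizers.lars import LARS  # assumed import path

model = torch.nn.Linear(10, 1)

# Per the warning above, the bias group (weight_decay=0) is excluded from
# layer-wise LR scaling; only the weight group gets the LARS trust ratio.
param_groups = [
    {"params": [model.weight], "weight_decay": 1e-6},
    {"params": [model.bias], "weight_decay": 0.0},
]

optimizer = LARS(param_groups, lr=0.1, momentum=0.9, trust_coefficient=0.001)
```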
Source code in fmcib/optimizers/lars.py
__setstate__(state)
Set the state of the optimizer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict | A dictionary containing the state of the optimizer. | required |
Returns:
| Type | Description |
|---|---|
| None | |
Note
This method overrides the `__setstate__` method of the superclass. It sets the state of the optimizer using the provided dictionary. Additionally, it sets the `nesterov` parameter in each parameter group to `False` if it is not already present.
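A minimal sketch of what this override amounts to, assuming the standard `torch.optim.Optimizer` state handling. This is an illustration of the behavior described in the note, not the actual source.

```python
def __setstate__(self, state):
    super().__setstate__(state)
    # Backfill the `nesterov` flag for groups restored from older checkpoints.
    for group in self.param_groups:
        group.setdefault("nesterov", False)
```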
Source code in fmcib/optimizers/lars.py
step(closure=None)
Performs a single optimization step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| closure | callable | A closure that reevaluates the model and returns the loss. | None |
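A hedged usage sketch of the closure form, following the standard `torch.optim` closure protocol. The model, data, and import path here are placeholders, not part of the documented API.

```python
import torch
from fmcib.optimizers.lars import LARS  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)

x = torch.randn(4, 10)
y = torch.randn(4, 1)

def closure():
    # Re-evaluates the model and returns the loss, as step(closure) expects.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)
```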