mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
362 lines
14 KiB
Python
362 lines
14 KiB
Python
import math
|
|
from typing import List
|
|
import torch
|
|
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
|
|
from optimum.quanto import QBytesTensor
|
|
import random
|
|
|
|
|
|
class Adafactor(torch.optim.Optimizer):
    """
    Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply.
    Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply.

    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used
        do_paramiter_swapping (`bool`, *optional*, defaults to `False`):
            If True, `enable_paramiter_swapping` is called at construction time so that only a random subset of
            the parameters has `requires_grad=True`.
        paramiter_swapping_factor (`float`, *optional*, defaults to 0.1):
            Fraction of the total parameter element count that may be active after each `swap_paramiters` call.
        stochastic_accumulation (`bool`, *optional*, defaults to `True`):
            If True, a post-accumulate-grad hook (`stochastic_grad_accummulation`) is registered on every
            trainable parameter whose dtype is not float32.

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

        - Training without LR warmup or clip_threshold is not recommended.

           - use scheduled LR warm-up to fixed LR
           - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
        - Disable relative updates
        - Use scale_parameter=False
        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
        stochastic_accumulation=True,
    ):
        """
        Build the optimizer, register the gradient accumulation hooks and count the parameters.

        Raises:
            ValueError: if a manual `lr` is combined with `relative_step=True`, or if
                `warmup_init=True` is requested without `relative_step=True`.
        """
        if lr is not None and relative_step:
            raise ValueError(
                "Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError(
                "`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

        # Read the value from each group so that a per-group `lr` override is
        # reported instead of the constructor-level default.
        self.base_lrs: List[float] = [
            group["lr"] for group in self.param_groups
        ]

        self.is_stochastic_rounding_accumulation = False

        # setup stochastic grad accum hooks
        if stochastic_accumulation:
            for group in self.param_groups:
                for param in group['params']:
                    if param.requires_grad and param.dtype != torch.float32:
                        self.is_stochastic_rounding_accumulation = True
                        param.register_post_accumulate_grad_hook(
                            stochastic_grad_accummulation
                        )

        self.do_paramiter_swapping = do_paramiter_swapping
        self.paramiter_swapping_factor = paramiter_swapping_factor

        # count total paramiters (element count over every group)
        self._total_paramiter_size = sum(
            torch.numel(param)
            for group in self.param_groups
            for param in group['params']
        )
        # pretty print total paramiters with comma separation
        print(f"Total training paramiters: {self._total_paramiter_size:,}")

        # must happen after counting: swapping needs `_total_paramiter_size`
        if self.do_paramiter_swapping:
            self.enable_paramiter_swapping(self.paramiter_swapping_factor)

    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
        """Turn parameter swapping on with the given factor and perform an initial swap."""
        self.do_paramiter_swapping = True
        self.paramiter_swapping_factor = paramiter_swapping_factor
        # call it an initial time
        self.swap_paramiters()

    def swap_paramiters(self):
        """
        Freeze every parameter, then re-enable a random subset.

        All parameters get `requires_grad=False` and their `.grad` cleared. They are then
        shuffled and re-enabled in that order for as long as the running element count
        stays within `_total_paramiter_size * paramiter_swapping_factor`.
        """
        all_params = []
        # deactivate all paramiters
        for group in self.param_groups:
            for param in group['params']:
                param.requires_grad_(False)
                # remove any grad
                param.grad = None
                all_params.append(param)
        # shuffle all paramiters
        random.shuffle(all_params)

        # keep activating paramiters until we are going to go over the target paramiters
        target_paramiters = int(
            self._total_paramiter_size * self.paramiter_swapping_factor)
        total_paramiters = 0
        for param in all_params:
            total_paramiters += torch.numel(param)
            # Strictly greater: a parameter that lands exactly on the budget is
            # still activated, so a factor of 1.0 enables every parameter.
            if total_paramiters > target_paramiters:
                break
            param.requires_grad_(True)

    @staticmethod
    def _get_lr(param_group, param_state):
        """
        Compute the effective step size for one parameter.

        With `relative_step` the group's `lr` is replaced by
        `min(min_step, 1 / sqrt(step))`, where `min_step` is `1e-6 * step` when
        `warmup_init` is set and `1e-2` otherwise. With `scale_parameter` the result is
        multiplied by `max(eps[1], RMS)`.

        `param_state` must contain `"step"` when `relative_step` is set and `"RMS"` when
        `scale_parameter` is set.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = (
                1e-6 * param_state["step"]
                if param_group["warmup_init"]
                else 1e-2
            )
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return `(factored, use_first_moment)`: factor when rank >= 2, use momentum when `beta1` is set."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root mean square of `tensor`: its L2 norm divided by sqrt(numel)."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """
        Combine the row and column statistics into an inverse-root scaling tensor.

        Returns the outer product of `rsqrt(row / mean(row))` and `rsqrt(col)`, which
        broadcasts to the shape of the factored gradient.
        """
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (
            exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
        ).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step_hook(self):
        """
        Move accumulated gradients from `param._accum_grad` into `param.grad`.

        No-op unless a stochastic accumulation hook was registered at construction.
        The `_accum_grad` attribute is removed after it has been transferred.
        """
        if not self.is_stochastic_rounding_accumulation:
            return
        # copy over stochastically rounded grads
        for group in self.param_groups:
            for param in group['params']:
                if param.requires_grad and hasattr(param, "_accum_grad"):
                    param.grad = param._accum_grad
                    del param._accum_grad

    # adafactor manages its own lr
    def get_learning_rates(self):
        """
        Return the current effective learning rate of each parameter group.

        A group is reported through its first parameter, and only when that parameter
        has a gradient and already owns optimizer state. When no group qualifies
        (for example before the first `step`), `base_lrs` is returned instead.
        """
        lrs = []
        for group in self.param_groups:
            params = group["params"]
            if not params or params[0].grad is None:
                continue
            # `.get` avoids inserting an empty entry into the state defaultdict.
            state = self.state.get(params[0])
            if not state:
                # A gradient exists but `step` has not run for this parameter
                # yet, so "step" / "RMS" are not available to `_get_lr`.
                continue
            lrs.append(self._get_lr(group, state))
        if len(lrs) == 0:
            lrs = self.base_lrs  # if called before stepping
        return lrs

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or `None` when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # `step` runs under no_grad; the closure needs autograd to be able
            # to call backward().
            with torch.enable_grad():
                loss = closure()

        # Runs after the closure so that gradients produced by the closure are
        # included. NOTE(review): assumes the accumulation hook stores
        # gradients in `_accum_grad` - confirm against optimizer_utils.
        self.step_hook()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adafactor does not support sparse gradients.")
                # all optimizer math is done in float32
                if grad.dtype != torch.float32:
                    grad = grad.to(torch.float32)

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(
                    group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # one statistic per row and one per column instead of
                        # a full second-moment tensor
                        state["exp_avg_sq_row"] = grad.new_zeros(
                            grad_shape[:-1])
                        state["exp_avg_sq_col"] = grad.new_zeros(
                            grad_shape[:-2] + grad_shape[-1:])
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # keep existing state on the gradient's device and dtype
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
                            grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
                            grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # For float32 parameters this aliases `p`, so the in-place
                # updates below modify the parameter directly.
                p_data_fp32 = p

                if isinstance(p_data_fp32, QBytesTensor):
                    p_data_fp32 = p_data_fp32.dequantize()
                if p.dtype != torch.float32:
                    # work on a float32 copy; written back with stochastic
                    # rounding at the end
                    p_data_fp32 = p_data_fp32.clone().float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                eps = group["eps"]
                if isinstance(eps, (tuple, list)):
                    # first entry regularizes the squared gradient
                    eps = eps[0]
                update = (grad ** 2) + eps
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(
                        exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # scale the update down when its RMS exceeds clip_threshold
                update.div_(
                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(
                        update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.sub_(update)

                if p.dtype != torch.float32:
                    # apply stochastic rounding
                    copy_stochastic(p, p_data_fp32)

        return loss
|