mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added adafactor implementation that handles stochastic rounding of update and accumulation
This commit is contained in:
@@ -1750,7 +1750,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
with torch.no_grad():
|
||||
# torch.cuda.empty_cache()
|
||||
if self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
# if optimizer has get_lrs method, then use it
|
||||
if hasattr(optimizer, 'get_learning_rates'):
|
||||
learning_rate = optimizer.get_learning_rates()[0]
|
||||
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
self.train_config.optimizer.lower().startswith('prodigy'):
|
||||
learning_rate = (
|
||||
optimizer.param_groups[0]["d"] *
|
||||
|
||||
@@ -77,9 +77,9 @@ def get_optimizer(
|
||||
except ImportError:
|
||||
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
|
||||
elif lower_type == 'adagrad':
|
||||
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
|
||||
elif lower_type == 'adafactor':
|
||||
# hack in stochastic rounding
|
||||
from toolkit.optimizers.adafactor import Adafactor
|
||||
if 'relative_step' not in optimizer_params:
|
||||
optimizer_params['relative_step'] = False
|
||||
if 'scale_parameter' not in optimizer_params:
|
||||
@@ -87,8 +87,6 @@ def get_optimizer(
|
||||
if 'warmup_init' not in optimizer_params:
|
||||
optimizer_params['warmup_init'] = False
|
||||
optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
|
||||
from toolkit.util.adafactor_stochastic_rounding import step_adafactor
|
||||
optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
|
||||
else:
|
||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||
return optimizer
|
||||
|
||||
305
toolkit/optimizers/adafactor.py
Normal file
305
toolkit/optimizers/adafactor.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import math
|
||||
from typing import List
|
||||
import torch
|
||||
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
|
||||
|
||||
|
||||
class Adafactor(torch.optim.Optimizer):
|
||||
"""
|
||||
Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply.
|
||||
Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply.
|
||||
|
||||
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
|
||||
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
||||
|
||||
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
|
||||
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
|
||||
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
||||
`relative_step=False`.
|
||||
|
||||
Arguments:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*):
|
||||
The external learning rate.
|
||||
eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
|
||||
Regularization constants for square gradient and parameter scale respectively
|
||||
clip_threshold (`float`, *optional*, defaults to 1.0):
|
||||
Threshold of root mean square of final gradient update
|
||||
decay_rate (`float`, *optional*, defaults to -0.8):
|
||||
Coefficient used to compute running averages of square
|
||||
beta1 (`float`, *optional*):
|
||||
Coefficient used for computing running averages of gradient
|
||||
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||
Weight decay (L2 penalty)
|
||||
scale_parameter (`bool`, *optional*, defaults to `True`):
|
||||
If True, learning rate is scaled by root mean square
|
||||
relative_step (`bool`, *optional*, defaults to `True`):
|
||||
If True, time-dependent learning rate is computed instead of external learning rate
|
||||
warmup_init (`bool`, *optional*, defaults to `False`):
|
||||
Time-dependent learning rate computation depends on whether warm-up initialization is being used
|
||||
|
||||
This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
|
||||
|
||||
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
|
||||
|
||||
- Training without LR warmup or clip_threshold is not recommended.
|
||||
|
||||
- use scheduled LR warm-up to fixed LR
|
||||
- use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
|
||||
- Disable relative updates
|
||||
- Use scale_parameter=False
|
||||
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
|
||||
```
|
||||
|
||||
Others reported the following combination to work well:
|
||||
|
||||
```python
|
||||
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
```
|
||||
|
||||
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
|
||||
scheduler as following:
|
||||
|
||||
```python
|
||||
from transformers.optimization import Adafactor, AdafactorSchedule
|
||||
|
||||
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
lr_scheduler = AdafactorSchedule(optimizer)
|
||||
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
|
||||
```
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
# replace AdamW with Adafactor
|
||||
optimizer = Adafactor(
|
||||
model.parameters(),
|
||||
lr=1e-3,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
beta1=None,
|
||||
weight_decay=0.0,
|
||||
relative_step=False,
|
||||
scale_parameter=False,
|
||||
warmup_init=False,
|
||||
)
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=None,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
beta1=None,
|
||||
weight_decay=0.0,
|
||||
scale_parameter=True,
|
||||
relative_step=True,
|
||||
warmup_init=False,
|
||||
):
|
||||
if lr is not None and relative_step:
|
||||
raise ValueError(
|
||||
"Cannot combine manual `lr` and `relative_step=True` options")
|
||||
if warmup_init and not relative_step:
|
||||
raise ValueError(
|
||||
"`warmup_init=True` requires `relative_step=True`")
|
||||
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
"eps": eps,
|
||||
"clip_threshold": clip_threshold,
|
||||
"decay_rate": decay_rate,
|
||||
"beta1": beta1,
|
||||
"weight_decay": weight_decay,
|
||||
"scale_parameter": scale_parameter,
|
||||
"relative_step": relative_step,
|
||||
"warmup_init": warmup_init,
|
||||
}
|
||||
super().__init__(params, defaults)
|
||||
|
||||
self.base_lrs: List[float] = [
|
||||
lr for group in self.param_groups
|
||||
]
|
||||
|
||||
self.is_stochastic_rounding_accumulation = False
|
||||
|
||||
# setup stochastic grad accum hooks
|
||||
for group in self.param_groups:
|
||||
for param in group['params']:
|
||||
if param.requires_grad and param.dtype != torch.float32:
|
||||
self.is_stochastic_rounding_accumulation = True
|
||||
param.register_post_accumulate_grad_hook(
|
||||
stochastic_grad_accummulation
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_lr(param_group, param_state):
|
||||
rel_step_sz = param_group["lr"]
|
||||
if param_group["relative_step"]:
|
||||
min_step = 1e-6 * \
|
||||
param_state["step"] if param_group["warmup_init"] else 1e-2
|
||||
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
|
||||
param_scale = 1.0
|
||||
if param_group["scale_parameter"]:
|
||||
param_scale = max(param_group["eps"][1], param_state["RMS"])
|
||||
return param_scale * rel_step_sz
|
||||
|
||||
@staticmethod
|
||||
def _get_options(param_group, param_shape):
|
||||
factored = len(param_shape) >= 2
|
||||
use_first_moment = param_group["beta1"] is not None
|
||||
return factored, use_first_moment
|
||||
|
||||
@staticmethod
|
||||
def _rms(tensor):
|
||||
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||
|
||||
@staticmethod
|
||||
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
|
||||
# copy from fairseq's adafactor implementation:
|
||||
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
|
||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-
|
||||
1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
||||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
||||
return torch.mul(r_factor, c_factor)
|
||||
|
||||
def step_hook(self):
|
||||
if not self.is_stochastic_rounding_accumulation:
|
||||
return
|
||||
# copy over stochastically rounded grads
|
||||
for group in self.param_groups:
|
||||
for param in group['params']:
|
||||
if param.requires_grad and hasattr(param, "_accum_grad"):
|
||||
param.grad = param._accum_grad
|
||||
del param._accum_grad
|
||||
|
||||
# adafactor manages its own lr
|
||||
def get_learning_rates(self):
|
||||
lrs = [
|
||||
self._get_lr(group, self.state[group["params"][0]])
|
||||
for group in self.param_groups
|
||||
if group["params"][0].grad is not None
|
||||
]
|
||||
if len(lrs) == 0:
|
||||
lrs = self.base_lrs # if called before stepping
|
||||
return lrs
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
self.step_hook()
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
grad = p.grad
|
||||
if grad.dtype != torch.float32:
|
||||
grad = grad.to(torch.float32)
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"Adafactor does not support sparse gradients.")
|
||||
|
||||
state = self.state[p]
|
||||
grad_shape = grad.shape
|
||||
|
||||
factored, use_first_moment = self._get_options(
|
||||
group, grad_shape)
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = torch.zeros(
|
||||
grad_shape[:-1]).to(grad)
|
||||
state["exp_avg_sq_col"] = torch.zeros(
|
||||
grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
state["RMS"] = 0
|
||||
else:
|
||||
if use_first_moment:
|
||||
state["exp_avg"] = state["exp_avg"].to(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
|
||||
grad)
|
||||
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
|
||||
grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
||||
|
||||
p_data_fp32 = p
|
||||
if p.dtype != torch.float32:
|
||||
p_data_fp32 = p_data_fp32.clone().float()
|
||||
|
||||
state["step"] += 1
|
||||
state["RMS"] = self._rms(p_data_fp32)
|
||||
lr = self._get_lr(group, state)
|
||||
|
||||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
||||
eps = group["eps"]
|
||||
if isinstance(eps, tuple) or isinstance(eps, list):
|
||||
eps = eps[0]
|
||||
update = (grad**2) + eps
|
||||
if factored:
|
||||
exp_avg_sq_row = state["exp_avg_sq_row"]
|
||||
exp_avg_sq_col = state["exp_avg_sq_col"]
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(
|
||||
update.mean(dim=-1), alpha=(1.0 - beta2t))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(
|
||||
update.mean(dim=-2), alpha=(1.0 - beta2t))
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(
|
||||
exp_avg_sq_row, exp_avg_sq_col)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_(
|
||||
(self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
||||
update.mul_(lr)
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg.mul_(group["beta1"]).add_(
|
||||
update, alpha=(1 - group["beta1"]))
|
||||
update = exp_avg
|
||||
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(
|
||||
p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.dtype != torch.float32:
|
||||
# apply stochastic rounding
|
||||
copy_stochastic(p, p_data_fp32)
|
||||
|
||||
return loss
|
||||
145
toolkit/optimizers/optimizer_utils.py
Normal file
145
toolkit/optimizers/optimizer_utils.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
|
||||
"""
|
||||
Returns (mantissa_bits, total_bits) for each format.
|
||||
mantissa_bits excludes the implicit leading 1.
|
||||
"""
|
||||
if dtype == torch.float32:
|
||||
return 23, 32
|
||||
elif dtype == torch.bfloat16:
|
||||
return 7, 16
|
||||
elif dtype == torch.float16:
|
||||
return 10, 16
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
return 3, 8
|
||||
elif dtype == torch.float8_e5m2:
|
||||
return 2, 8
|
||||
elif dtype == torch.int8:
|
||||
return 0, 8 # Int8 doesn't have mantissa bits
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def copy_stochastic(
|
||||
target: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
eps: Optional[float] = None
|
||||
) -> None:
|
||||
"""
|
||||
Performs stochastic rounding from source tensor to target tensor.
|
||||
|
||||
Args:
|
||||
target: Destination tensor (determines the target format)
|
||||
source: Source tensor (typically float32)
|
||||
eps: Optional minimum value for stochastic rounding (for numerical stability)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# If target is float32, just copy directly
|
||||
if target.dtype == torch.float32:
|
||||
target.copy_(source)
|
||||
return
|
||||
|
||||
# Special handling for int8
|
||||
if target.dtype == torch.int8:
|
||||
# Scale the source values to utilize the full int8 range
|
||||
scaled = source * 127.0 # Scale to [-127, 127]
|
||||
|
||||
# Add random noise for stochastic rounding
|
||||
noise = torch.rand_like(scaled) - 0.5
|
||||
rounded = torch.round(scaled + noise)
|
||||
|
||||
# Clamp to int8 range
|
||||
clamped = torch.clamp(rounded, -127, 127)
|
||||
target.copy_(clamped.to(torch.int8))
|
||||
return
|
||||
|
||||
mantissa_bits, _ = get_format_params(target.dtype)
|
||||
|
||||
# Convert source to int32 view
|
||||
source_int = source.view(dtype=torch.int32)
|
||||
|
||||
# Calculate number of bits to round
|
||||
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
|
||||
|
||||
# Create random integers for stochastic rounding
|
||||
rand = torch.randint_like(
|
||||
source,
|
||||
dtype=torch.int32,
|
||||
low=0,
|
||||
high=(1 << bits_to_round),
|
||||
)
|
||||
|
||||
# Add random values to the bits that will be rounded off
|
||||
result = source_int.clone()
|
||||
result.add_(rand)
|
||||
|
||||
# Mask to keep only the bits we want
|
||||
# Create mask with 1s in positions we want to keep
|
||||
mask = (-1) << bits_to_round
|
||||
result.bitwise_and_(mask)
|
||||
|
||||
# Handle minimum value threshold if specified
|
||||
if eps is not None:
|
||||
eps_int = torch.tensor(
|
||||
eps, dtype=torch.float32).view(dtype=torch.int32)
|
||||
zero_mask = (result.abs() < eps_int)
|
||||
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
|
||||
|
||||
# Convert back to float32 view
|
||||
result_float = result.view(dtype=torch.float32)
|
||||
|
||||
# Special handling for float8 formats
|
||||
if target.dtype == torch.float8_e4m3fn:
|
||||
result_float.clamp_(-448.0, 448.0)
|
||||
elif target.dtype == torch.float8_e5m2:
|
||||
result_float.clamp_(-57344.0, 57344.0)
|
||||
|
||||
target.copy_(result_float)
|
||||
del result, rand, source_int
|
||||
|
||||
|
||||
class Auto8bitTensor:
|
||||
def __init__(self, data: Tensor, *args, **kwargs):
|
||||
|
||||
abs_max = data.abs().max().item()
|
||||
scale = abs_max / 127.0 if abs_max > 0 else 1.0
|
||||
|
||||
self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
|
||||
self.scale = scale
|
||||
self.orig_dtype = data.dtype
|
||||
|
||||
def dequantize(self) -> Tensor:
|
||||
return self.quantized.to(dtype=torch.float32) * self.scale
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# Handle the dtype argument whether it's positional or keyword
|
||||
dtype = None
|
||||
if args and isinstance(args[0], torch.dtype):
|
||||
dtype = args[0]
|
||||
args = args[1:]
|
||||
elif 'dtype' in kwargs:
|
||||
dtype = kwargs['dtype']
|
||||
del kwargs['dtype']
|
||||
|
||||
if dtype is not None:
|
||||
# First dequantize then convert to requested dtype
|
||||
return self.dequantize().to(dtype=dtype, *args, **kwargs)
|
||||
|
||||
# If no dtype specified, just pass through to parent
|
||||
return self.dequantize().to(*args, **kwargs)
|
||||
|
||||
|
||||
def stochastic_grad_accummulation(param):
|
||||
if hasattr(param, "_accum_grad"):
|
||||
grad_fp32 = param._accum_grad.clone().to(torch.float32)
|
||||
grad_fp32.add_(param.grad.to(torch.float32))
|
||||
copy_stochastic(param._accum_grad, grad_fp32)
|
||||
del grad_fp32
|
||||
del param.grad
|
||||
else:
|
||||
param._accum_grad = param.grad.clone()
|
||||
del param.grad
|
||||
@@ -1,150 +1,8 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
from torch.optim import Optimizer
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
|
||||
"""
|
||||
Returns (mantissa_bits, total_bits) for each format.
|
||||
mantissa_bits excludes the implicit leading 1.
|
||||
"""
|
||||
if dtype == torch.float32:
|
||||
return 23, 32
|
||||
elif dtype == torch.bfloat16:
|
||||
return 7, 16
|
||||
elif dtype == torch.float16:
|
||||
return 10, 16
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
return 3, 8
|
||||
elif dtype == torch.float8_e5m2:
|
||||
return 2, 8
|
||||
elif dtype == torch.int8:
|
||||
return 0, 8 # Int8 doesn't have mantissa bits
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def copy_stochastic(
|
||||
target: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
eps: Optional[float] = None
|
||||
) -> None:
|
||||
"""
|
||||
Performs stochastic rounding from source tensor to target tensor.
|
||||
|
||||
Args:
|
||||
target: Destination tensor (determines the target format)
|
||||
source: Source tensor (typically float32)
|
||||
eps: Optional minimum value for stochastic rounding (for numerical stability)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# If target is float32, just copy directly
|
||||
if target.dtype == torch.float32:
|
||||
target.copy_(source)
|
||||
return
|
||||
|
||||
# Special handling for int8
|
||||
if target.dtype == torch.int8:
|
||||
# Scale the source values to utilize the full int8 range
|
||||
scaled = source * 127.0 # Scale to [-127, 127]
|
||||
|
||||
# Add random noise for stochastic rounding
|
||||
noise = torch.rand_like(scaled) - 0.5
|
||||
rounded = torch.round(scaled + noise)
|
||||
|
||||
# Clamp to int8 range
|
||||
clamped = torch.clamp(rounded, -127, 127)
|
||||
target.copy_(clamped.to(torch.int8))
|
||||
return
|
||||
|
||||
mantissa_bits, _ = get_format_params(target.dtype)
|
||||
|
||||
# Convert source to int32 view
|
||||
source_int = source.view(dtype=torch.int32)
|
||||
|
||||
# Calculate number of bits to round
|
||||
bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
|
||||
|
||||
# Create random integers for stochastic rounding
|
||||
rand = torch.randint_like(
|
||||
source,
|
||||
dtype=torch.int32,
|
||||
low=0,
|
||||
high=(1 << bits_to_round),
|
||||
)
|
||||
|
||||
# Add random values to the bits that will be rounded off
|
||||
result = source_int.clone()
|
||||
result.add_(rand)
|
||||
|
||||
# Mask to keep only the bits we want
|
||||
# Create mask with 1s in positions we want to keep
|
||||
mask = (-1) << bits_to_round
|
||||
result.bitwise_and_(mask)
|
||||
|
||||
# Handle minimum value threshold if specified
|
||||
if eps is not None:
|
||||
eps_int = torch.tensor(
|
||||
eps, dtype=torch.float32).view(dtype=torch.int32)
|
||||
zero_mask = (result.abs() < eps_int)
|
||||
result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
|
||||
|
||||
# Convert back to float32 view
|
||||
result_float = result.view(dtype=torch.float32)
|
||||
|
||||
# Special handling for float8 formats
|
||||
if target.dtype == torch.float8_e4m3fn:
|
||||
result_float.clamp_(-448.0, 448.0)
|
||||
elif target.dtype == torch.float8_e5m2:
|
||||
result_float.clamp_(-57344.0, 57344.0)
|
||||
|
||||
target.copy_(result_float)
|
||||
|
||||
|
||||
class Auto8bitTensor:
|
||||
def __init__(self, data: Tensor, *args, **kwargs):
|
||||
|
||||
abs_max = data.abs().max().item()
|
||||
scale = abs_max / 127.0 if abs_max > 0 else 1.0
|
||||
|
||||
self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
|
||||
self.scale = scale
|
||||
self.orig_dtype = data.dtype
|
||||
|
||||
def dequantize(self) -> Tensor:
|
||||
return self.quantized.to(dtype=torch.float32) * self.scale
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# Handle the dtype argument whether it's positional or keyword
|
||||
dtype = None
|
||||
if args and isinstance(args[0], torch.dtype):
|
||||
dtype = args[0]
|
||||
args = args[1:]
|
||||
elif 'dtype' in kwargs:
|
||||
dtype = kwargs['dtype']
|
||||
del kwargs['dtype']
|
||||
|
||||
if dtype is not None:
|
||||
# First dequantize then convert to requested dtype
|
||||
return self.dequantize().to(dtype=dtype, *args, **kwargs)
|
||||
|
||||
# If no dtype specified, just pass through to parent
|
||||
return self.dequantize().to(*args, **kwargs)
|
||||
|
||||
|
||||
def stochastic_grad_accummulation(param):
|
||||
if hasattr(param, "_accum_grad"):
|
||||
grad_fp32 = param._accum_grad.clone().to(torch.float32)
|
||||
grad_fp32.add_(param.grad.to(torch.float32))
|
||||
copy_stochastic(param._accum_grad, grad_fp32)
|
||||
del grad_fp32
|
||||
del param.grad
|
||||
else:
|
||||
param._accum_grad = param.grad.clone()
|
||||
del param.grad
|
||||
from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation
|
||||
|
||||
|
||||
class Prodigy8bit(Optimizer):
|
||||
@@ -222,7 +80,7 @@ class Prodigy8bit(Optimizer):
|
||||
fsdp_in_use=fsdp_in_use)
|
||||
self.d0 = d0
|
||||
super(Prodigy8bit, self).__init__(params, defaults)
|
||||
|
||||
|
||||
self.is_stochastic_rounding_accumulation = False
|
||||
|
||||
# setup stochastic grad accum hooks
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
# ref https://github.com/Nerogar/OneTrainer/compare/master...stochastic_rounding
|
||||
import math
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def copy_stochastic_(target: Tensor, source: Tensor):
|
||||
# create a random 16 bit integer
|
||||
result = torch.randint_like(
|
||||
source,
|
||||
dtype=torch.int32,
|
||||
low=0,
|
||||
high=(1 << 16),
|
||||
)
|
||||
|
||||
# add the random number to the lower 16 bit of the mantissa
|
||||
result.add_(source.view(dtype=torch.int32))
|
||||
|
||||
# mask off the lower 16 bit of the mantissa
|
||||
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
||||
|
||||
# copy the higher 16 bit into the target tensor
|
||||
target.copy_(result.view(dtype=torch.float32))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step_adafactor(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.dtype in {torch.float16, torch.bfloat16}:
|
||||
grad = grad.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adafactor does not support sparse gradients.")
|
||||
|
||||
state = self.state[p]
|
||||
grad_shape = grad.shape
|
||||
|
||||
factored, use_first_moment = self._get_options(group, grad_shape)
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
||||
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
state["RMS"] = 0
|
||||
else:
|
||||
if use_first_moment:
|
||||
state["exp_avg"] = state["exp_avg"].to(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
||||
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
||||
|
||||
p_data_fp32 = p
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p_data_fp32 = p_data_fp32.float()
|
||||
|
||||
state["step"] += 1
|
||||
state["RMS"] = self._rms(p_data_fp32)
|
||||
lr = self._get_lr(group, state)
|
||||
|
||||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
||||
eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"]
|
||||
update = (grad ** 2) + eps
|
||||
if factored:
|
||||
exp_avg_sq_row = state["exp_avg_sq_row"]
|
||||
exp_avg_sq_col = state["exp_avg_sq_col"]
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
||||
update.mul_(lr)
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
||||
update = exp_avg
|
||||
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.dtype == torch.bfloat16:
|
||||
copy_stochastic_(p, p_data_fp32)
|
||||
elif p.dtype == torch.float16:
|
||||
p.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
||||
Reference in New Issue
Block a user