from typing import List

import torch


class Automagic2(torch.optim.Optimizer):
    """
    Automagic v2.

    A single scalar learning rate is kept per parameter (e.g. one lr for the
    full weight matrix of a Linear layer rather than one per element). The lr
    is nudged up when the per-element update direction stays consistent with
    the previous step and nudged down when it flips, clamped to
    [min_lr, max_lr].

    The optimizer step is fused into the backward pass via
    ``register_post_accumulate_grad_hook``: each parameter is updated and its
    grad freed as soon as autograd finishes accumulating into it. ``.step()``
    therefore does no real work and peak VRAM stays low.

    Second-moment EMA state is stored in ``p.dtype`` (math runs in fp32 when
    the state is lower precision). Stochastic rounding is applied only when
    writing back to a bf16 parameter.

    A usage sketch sits at the bottom of this module.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-6,
        min_lr: float = 1e-7,
        max_lr: float = 1e-3,
        lr_bump: float = 1e-6,
        beta2: float = 0.999,
        eps: float = 1e-30,
        clip_threshold: float = 1.0,
        weight_decay: float = 0.0,
        agreement_threshold: float = 0.6,
    ):
        if lr > 1e-3:
            print(f"Warning! Start lr {lr} is very high; forcing to 1e-6.")
            lr = 1e-6
        defaults = dict(
            lr=lr,
            min_lr=min_lr,
            max_lr=max_lr,
            lr_bump=lr_bump,
            beta2=beta2,
            eps=eps,
            clip_threshold=clip_threshold,
            weight_decay=weight_decay,
            agreement_threshold=agreement_threshold,
        )
        super().__init__(params, defaults)

        # Fuse the step into backward: each trainable param gets a hook that
        # updates it (and frees its grad) as soon as autograd is done with it.
        self._hook_handles = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    handle = p.register_post_accumulate_grad_hook(
                        self._make_backward_hook(group)
                    )
                    self._hook_handles.append(handle)

        total = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"]
            if p.requires_grad
        )
        print(f"Total training parameters: {total:,}")

    # ------------------------------------------------------------------ utils

    @staticmethod
    def _rms(t: torch.Tensor) -> torch.Tensor:
        return t.norm(2) / (t.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
        # Adafactor-style factored preconditioner: rebuild 1/sqrt(V) from the
        # row/col second-moment accumulators, normalizing the row factor by
        # its mean so the outer product has the right overall scale.
        r = (row / row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c = col.unsqueeze(-2).rsqrt()
        return torch.mul(r, c)
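
    # Shape sketch for _approx_sq_grad (illustrative values, not from the
    # original file): for an (m, n) weight the accumulators are row: (m,) and
    # col: (n,), and the factored preconditioner broadcasts back to (m, n):
    #
    #     row = torch.tensor([1.0, 4.0])              # (2,)
    #     col = torch.tensor([1.0, 1.0, 4.0])         # (3,)
    #     pre = Automagic2._approx_sq_grad(row, col)  # (2, 3), ~ 1/sqrt(V)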
    def _init_state(self, p: torch.Tensor, group: dict) -> None:
        state = self.state[p]
        state["step"] = 0
        # One scalar (0-D) lr per parameter, kept in fp32 on the param's device.
        state["lr"] = torch.full(
            (), float(group["lr"]), dtype=torch.float32, device=p.device
        )
        state["last_polarity"] = torch.zeros(p.shape, dtype=torch.bool, device=p.device)
        if p.dim() >= 2:
            # Factored second moment (row/col accumulators) for matrices.
            state["exp_avg_sq_row"] = torch.zeros(
                p.shape[:-1], dtype=p.dtype, device=p.device
            )
            state["exp_avg_sq_col"] = torch.zeros(
                p.shape[:-2] + p.shape[-1:], dtype=p.dtype, device=p.device
            )
        else:
            state["exp_avg_sq"] = torch.zeros(p.shape, dtype=p.dtype, device=p.device)
    def _make_backward_hook(self, group):
        # Factory rather than a lambda in the registration loop, so each hook
        # binds its own `group` instead of late-binding the loop variable.
        def _hook(p: torch.Tensor):
            self._update_param(p, group)

        return _hook

    # -------------------------------------------------------------- per-param
    @torch.no_grad()
    def _update_param(self, p: torch.Tensor, group: dict) -> None:
        if p.grad is None:
            return
        state = self.state[p]
        if len(state) == 0:
            self._init_state(p, group)

        grad = p.grad
        if grad.is_sparse:
            raise RuntimeError("Automagic2 does not support sparse gradients.")
        if grad.dtype != torch.float32:
            grad = grad.to(torch.float32)

        beta2 = group["beta2"]
        eps = group["eps"]
        sq = (grad * grad).add_(eps)

        if p.dim() >= 2:
            row_state = state["exp_avg_sq_row"]
            col_state = state["exp_avg_sq_col"]
            if row_state.dtype == torch.float32:
                row, col = row_state, col_state
                row.mul_(beta2).add_(sq.mean(dim=-1), alpha=1.0 - beta2)
                col.mul_(beta2).add_(sq.mean(dim=-2), alpha=1.0 - beta2)
            else:
                # Low-precision state: do the EMA math in fp32, then write back.
                row = row_state.to(torch.float32)
                col = col_state.to(torch.float32)
                row.mul_(beta2).add_(sq.mean(dim=-1), alpha=1.0 - beta2)
                col.mul_(beta2).add_(sq.mean(dim=-2), alpha=1.0 - beta2)
                row_state.copy_(row.to(row_state.dtype))
                col_state.copy_(col.to(col_state.dtype))
            update = self._approx_sq_grad(row, col).mul_(grad)
        else:
            v_state = state["exp_avg_sq"]
            if v_state.dtype == torch.float32:
                v = v_state
                v.mul_(beta2).add_(sq, alpha=1.0 - beta2)
            else:
                v = v_state.to(torch.float32)
                v.mul_(beta2).add_(sq, alpha=1.0 - beta2)
                v_state.copy_(v.to(v_state.dtype))
            update = v.rsqrt().mul_(grad)

        # Adafactor-style RMS clipping of the raw update.
        update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

        # Per-element sign agreement collapsed to a single bump decision.
        # Kept on-device as a 0-D tensor to avoid a CPU<->GPU sync in the hot path.
        cur_polarity = update > 0
        last_polarity = state["last_polarity"]
        agreement = (cur_polarity == last_polarity).to(torch.float32).mean()
        state["last_polarity"] = cur_polarity

        lr_t = state["lr"]
        if state["step"] > 0:
            # +lr_bump when agreement clears the threshold, -lr_bump otherwise.
            direction = (agreement >= group["agreement_threshold"]).to(lr_t.dtype) * 2.0 - 1.0
            lr_t.add_(direction, alpha=group["lr_bump"]).clamp_(
                min=group["min_lr"], max=group["max_lr"]
            )
        state["step"] += 1

        update.mul_(lr_t)
        wd = group["weight_decay"]

        if p.dtype == torch.bfloat16:
            # Single bf16 -> fp32 conversion shared by weight decay and SR.
            new_p_fp32 = p.to(torch.float32)
            if wd != 0.0:
                update.addcmul_(new_p_fp32, lr_t, value=wd)
            new_p_fp32.sub_(update)
            # Stochastic rounding fp32 -> bf16: add random noise into the lower
            # 16 mantissa bits, then truncate. Done in place on new_p_fp32 so
            # we don't allocate a separate int32 work buffer.
            as_int = new_p_fp32.view(torch.int32)
            as_int.add_(torch.randint_like(as_int, 1 << 16)).bitwise_and_(-65536)
            p.copy_(new_p_fp32)
        else:
            if wd != 0.0:
                # Decoupled weight decay, scaled by the per-param lr.
                p_fp32 = p if p.dtype == torch.float32 else p.to(torch.float32)
                update.addcmul_(p_fp32, lr_t, value=wd)
            p.add_(update.to(p.dtype), alpha=-1.0)

        # Free the gradient immediately to keep peak VRAM low.
        p.grad = None
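
    # Worked SR example (illustrative numbers): v = 1 + 2**-17 has fp32 bit
    # pattern 0x3F800040; the low 16 bits (0x0040) are exactly what bf16
    # drops. Adding uniform noise from [0, 0x10000) carries into the kept
    # bits with probability 0x0040 / 0x10000 = 2**-10, so v rounds up to
    # 1 + 2**-7 (the next bf16 value above 1.0) that often, and down to 1.0
    # otherwise; the rounded value is therefore unbiased in expectation.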

    # ----------------------------------------------------------- optimizer API

    @torch.no_grad()
    def step(self, closure=None):
        # The real updates already ran inside the backward hooks; step() only
        # evaluates the closure (if any) to stay API-compatible.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        return loss
    def get_learning_rates(self) -> List[float]:
        out = []
        for group in self.param_groups:
            lrs = [
                float(self.state[p]["lr"])
                for p in group["params"]
                if p in self.state and "lr" in self.state[p]
            ]
            out.append(sum(lrs) / len(lrs) if lrs else float(group["lr"]))
        return out

    def get_avg_learning_rate(self) -> float:
        lrs = self.get_learning_rates()
        return sum(lrs) / len(lrs) if lrs else float(self.defaults["lr"])
    def load_state_dict(self, state_dict):
        # Parent casts every fp state tensor to param.dtype; force lr back to
        # fp32 so subsequent lr_bump (default 1e-6) isn't rounded away on bf16
        # weights.
        super().load_state_dict(state_dict)
        # Constructor args always win over whatever was saved in the checkpoint.
        for group in self.param_groups:
            for k, v in self.defaults.items():
                group[k] = v
            for p in group["params"]:
                st = self.state.get(p)
                if st is not None and isinstance(st.get("lr"), torch.Tensor):
                    st["lr"] = st["lr"].to(torch.float32)
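

# ------------------------------------------------------------------ usage demo

# Minimal usage sketch referenced from the class docstring. This block is
# illustrative only: the toy network, shapes, and loop below are assumptions,
# not part of the optimizer itself.
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    opt = Automagic2(net.parameters(), lr=1e-6)

    for _ in range(3):
        x = torch.randn(8, 16)
        loss = net(x).pow(2).mean()
        # Parameters are updated (and grads freed) inside the
        # post-accumulate-grad hooks while backward() runs.
        loss.backward()
        opt.step()  # no-op apart from the optional closure
    print(f"avg per-param lr: {opt.get_avg_learning_rate():.2e}")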