diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index e900d004..d1a258ff 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -95,6 +95,9 @@ def get_optimizer(
         if 'warmup_init' not in optimizer_params:
             optimizer_params['warmup_init'] = False
         optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+    elif lower_type == 'automagic':
+        from toolkit.optimizers.automagic import Automagic
+        optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py
new file mode 100644
index 00000000..c7eac6f5
--- /dev/null
+++ b/toolkit/optimizers/automagic.py
@@ -0,0 +1,296 @@
+import math
+import random
+from typing import List
+
+import torch
+from optimum.quanto import QBytesTensor
+
+from toolkit.optimizers.optimizer_utils import (
+    Auto8bitTensor,
+    copy_stochastic,
+    stochastic_grad_accummulation,
+)
+
+
+class Automagic(torch.optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=None,
+        min_lr=1e-7,
+        max_lr=1e-2,
+        lr_momentum=0.9,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        weight_decay=0.0,
+        do_paramiter_swapping=False,
+        paramiter_swapping_factor=0.1,
+    ):
+        self.lr = lr
+        self.min_lr = min_lr
+        self.max_lr = max_lr
+        self.lr_momentum = lr_momentum
+
+        defaults = {
+            "lr": lr,
+            "eps": eps,
+            "clip_threshold": clip_threshold,
+            "decay_rate": decay_rate,
+            "weight_decay": weight_decay,
+        }
+        super().__init__(params, defaults)
+
+        self.base_lrs: List[float] = [lr for _ in self.param_groups]
+
+        self.is_stochastic_rounding_accumulation = False
+
+        # set up stochastic grad accumulation hooks for non-fp32 params
+        for group in self.param_groups:
+            for param in group['params']:
+                if param.requires_grad and param.dtype != torch.float32:
+                    self.is_stochastic_rounding_accumulation = True
+                    param.register_post_accumulate_grad_hook(
+                        stochastic_grad_accummulation
+                    )
+
+        self.do_paramiter_swapping = do_paramiter_swapping
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        self._total_paramiter_size = 0
+        # count total parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                self._total_paramiter_size += torch.numel(param)
+        # pretty print the total with comma separators
+        print(f"Total training parameters: {self._total_paramiter_size:,}")
+
+        # swapping needs the total parameter count computed above
+        if self.do_paramiter_swapping:
+            self.enable_paramiter_swapping(self.paramiter_swapping_factor)
+
+    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
+        self.do_paramiter_swapping = True
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        # do an initial swap so a subset of parameters is active immediately
+        self.swap_paramiters()
+
+    def swap_paramiters(self):
+        all_params = []
+        # deactivate all parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                param.requires_grad_(False)
+                # remove any grad
+                param.grad = None
+                all_params.append(param)
+        # shuffle all parameters
+        random.shuffle(all_params)
+
+        # keep activating parameters until we would exceed the target count
+        target_paramiters = int(
+            self._total_paramiter_size * self.paramiter_swapping_factor)
+        total_paramiters = 0
+        for param in all_params:
+            total_paramiters += torch.numel(param)
+            if total_paramiters >= target_paramiters:
+                break
+            else:
+                param.requires_grad_(True)
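+    # Every parameter tensor carries its own elementwise lr mask in
+    # state['lr_mask']; state['avg_lr'] caches the mask's mean so callers can
+    # read a scalar learning rate without dequantizing the mask each time.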
+    @staticmethod
+    def _get_lr(param_group, param_state):
+        # the averaged lr mask lives in the parameter state, not on the
+        # group; fall back to the group's base lr before the first step
+        lr = param_state.get("avg_lr", param_group["lr"])
+        param_scale = 1.0
+        return param_scale * lr
+
+    def _get_group_lr(self, group):
+        group_lrs = []
+        for p in group["params"]:
+            if p.grad is not None:
+                group_lrs.append(self._get_lr(group, self.state[p]))
+        # return the average, or the base lr if no param has a grad yet
+        if len(group_lrs) == 0:
+            return self.lr
+        return sum(group_lrs) / len(group_lrs)
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    @staticmethod
+    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+        # copied from the Hugging Face port of fairseq's Adafactor:
+        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+        r_factor = (
+            exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
+        ).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
+    def step_hook(self):
+        if not self.is_stochastic_rounding_accumulation:
+            return
+        # copy over stochastically rounded grads
+        for group in self.param_groups:
+            for param in group['params']:
+                if param.requires_grad and hasattr(param, "_accum_grad"):
+                    param.grad = param._accum_grad
+                    del param._accum_grad
+
+    # like Adafactor, this optimizer manages its own learning rates
+    def get_learning_rates(self):
+        lrs = [
+            self._get_group_lr(group)
+            for group in self.param_groups
+        ]
+        if len(lrs) == 0:
+            lrs = self.base_lrs  # if called before stepping
+        return lrs
+
+    def get_avg_learning_rate(self):
+        lrs = self.get_learning_rates()
+        return sum(lrs) / len(lrs)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        self.step_hook()
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None or not p.requires_grad:
+                    continue
+
+                grad = p.grad
+                if grad.dtype != torch.float32:
+                    grad = grad.to(torch.float32)
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Automagic does not support sparse gradients.")
+
+                state = self.state[p]
+                grad_shape = grad.shape
+
+                factored = len(grad_shape) >= 2
+                # state initialization
+                if len(state) == 0:
+                    state["step"] = 0
+
+                    # store the lr mask, quantized to 8 bits to save memory
+                    state['lr_mask'] = Auto8bitTensor(
+                        torch.ones(p.shape, dtype=torch.float32, device=p.device) * self.lr
+                    )
+                    state['avg_lr'] = torch.mean(
+                        state['lr_mask'].to(torch.float32))
+                    state['last_polarity'] = torch.zeros(
+                        p.shape, dtype=torch.bool, device=p.device)
+
+                    if factored:
+                        state["exp_avg_sq_row"] = torch.zeros(
+                            grad_shape[:-1]).to(grad)
+                        state["exp_avg_sq_col"] = torch.zeros(
+                            grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                    else:
+                        state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                    state["RMS"] = 0
+                else:
+                    if factored:
+                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    else:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+                p_data_fp32 = p
+                if isinstance(p_data_fp32, QBytesTensor):
+                    p_data_fp32 = p_data_fp32.dequantize()
+                if p.dtype != torch.float32:
+                    p_data_fp32 = p_data_fp32.clone().float()
+
+                state["step"] += 1
+                state["RMS"] = self._rms(p_data_fp32)
+
+                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                eps = group["eps"]
+                if isinstance(eps, (tuple, list)):
+                    eps = eps[0]
+                update = (grad**2) + eps
+                if factored:
+                    exp_avg_sq_row = state["exp_avg_sq_row"]
+                    exp_avg_sq_col = state["exp_avg_sq_col"]
+
+                    exp_avg_sq_row.mul_(beta2t).add_(
+                        update.mean(dim=-1), alpha=(1.0 - beta2t))
+                    exp_avg_sq_col.mul_(beta2t).add_(
+                        update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+                    # approximation of the exponential moving average of the
+                    # squared gradient
+                    update = self._approx_sq_grad(
+                        exp_avg_sq_row, exp_avg_sq_col)
+                    update.mul_(grad)
+                else:
+                    exp_avg_sq = state["exp_avg_sq"]
+
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    update = exp_avg_sq.rsqrt().mul_(grad)
+
+                update.div_(
+                    (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
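+                # Worked example of the sign-agreement rule below: with the
+                # default lr_momentum=0.9, an element whose update keeps the
+                # same sign on consecutive steps has its lr divided by 0.9
+                # (about +11%), while a sign flip multiplies it by 0.9
+                # (-10%); the result is clamped to [min_lr, max_lr].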
+                # Adapt the per-element lr mask: raise lr where the update
+                # direction repeats, lower it where the direction flips.
+                last_polarity = state['last_polarity']
+                current_polarity = (update > 0).to(torch.bool)
+                sign_agreement = torch.where(
+                    last_polarity == current_polarity, 1, -1)
+                state['last_polarity'] = current_polarity
+
+                lr_mask = state['lr_mask'].to(torch.float32)
+
+                # update the learning rate mask based on sign agreement
+                new_lr = torch.where(
+                    sign_agreement > 0,
+                    lr_mask / self.lr_momentum,  # increase lr
+                    lr_mask * self.lr_momentum   # decrease lr
+                )
+
+                # clip learning rates to bounds
+                new_lr = torch.clamp(new_lr, min=self.min_lr, max=self.max_lr)
+
+                # apply the learning rate mask to the update
+                update.mul_(new_lr)
+
+                state['lr_mask'] = Auto8bitTensor(new_lr)
+                state['avg_lr'] = torch.mean(new_lr)
+
+                if group["weight_decay"] != 0:
+                    # new_lr is a tensor, so it cannot be passed as `alpha`;
+                    # apply the elementwise lr to the decay term instead
+                    p_data_fp32.add_(
+                        p_data_fp32 * new_lr, alpha=-group["weight_decay"])
+
+                p_data_fp32.add_(-update)
+
+                if p.dtype != torch.float32:
+                    # write back with stochastic rounding
+                    copy_stochastic(p, p_data_fp32)
+
+        return loss
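+
+# Usage sketch (illustrative only; `model` and `loss` are assumed to come
+# from the surrounding training loop):
+#
+#   optimizer = Automagic(model.parameters(), lr=1e-4)
+#   loss.backward()
+#   optimizer.step()
+#   optimizer.zero_grad()
+#   print(optimizer.get_avg_learning_rate())  # mean of the learned lr masks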