diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d9a0fb8c..fafeff12 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1750,7 +1750,10 @@ class BaseSDTrainProcess(BaseTrainProcess): with torch.no_grad(): # torch.cuda.empty_cache() - if self.train_config.optimizer.lower().startswith('dadaptation') or \ + # if optimizer has get_lrs method, then use it + if hasattr(optimizer, 'get_learning_rates'): + learning_rate = optimizer.get_learning_rates()[0] + elif self.train_config.optimizer.lower().startswith('dadaptation') or \ self.train_config.optimizer.lower().startswith('prodigy'): learning_rate = ( optimizer.param_groups[0]["d"] * diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index 12a630ac..473e333d 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -77,9 +77,9 @@ def get_optimizer( except ImportError: raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch") elif lower_type == 'adagrad': - optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) + optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params) elif lower_type == 'adafactor': - # hack in stochastic rounding + from toolkit.optimizers.adafactor import Adafactor if 'relative_step' not in optimizer_params: optimizer_params['relative_step'] = False if 'scale_parameter' not in optimizer_params: @@ -87,8 +87,6 @@ def get_optimizer( if 'warmup_init' not in optimizer_params: optimizer_params['warmup_init'] = False optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params) - from toolkit.util.adafactor_stochastic_rounding import step_adafactor - optimizer.step = step_adafactor.__get__(optimizer, Adafactor) else: raise ValueError(f'Unknown optimizer type {optimizer_type}') return optimizer diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py new file mode 100644 index 00000000..3f4a738c --- /dev/null +++ b/toolkit/optimizers/adafactor.py @@ -0,0 +1,305 @@ +import math +from typing import List +import torch +from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation + + +class Adafactor(torch.optim.Optimizer): + """ + Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply. + Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply. + + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. 
+ eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults to 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0.0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + + - Training without LR warmup or clip_threshold is not recommended. + + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + + Example: + + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + + Others reported the following combination to work well: + + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + + Usage: + + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError( + "Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError( + "`warmup_init=True` requires `relative_step=True`") + + defaults = { + "lr": lr, + "eps": eps, + "clip_threshold": clip_threshold, + "decay_rate": decay_rate, + "beta1": beta1, + "weight_decay": weight_decay, + "scale_parameter": scale_parameter, + "relative_step": relative_step, + "warmup_init": warmup_init, + } + super().__init__(params, defaults) + + self.base_lrs: List[float] = [ + lr for group in self.param_groups + ] + + self.is_stochastic_rounding_accumulation = False + + # setup stochastic grad accum hooks + for group in self.param_groups: + for param in group['params']: 
+ if param.requires_grad and param.dtype != torch.float32: + self.is_stochastic_rounding_accumulation = True + param.register_post_accumulate_grad_hook( + stochastic_grad_accummulation + ) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * \ + param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=- + 1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step_hook(self): + if not self.is_stochastic_rounding_accumulation: + return + # copy over stochastically rounded grads + for group in self.param_groups: + for param in group['params']: + if param.requires_grad and hasattr(param, "_accum_grad"): + param.grad = param._accum_grad + del param._accum_grad + + # adafactor manages its own lr + def get_learning_rates(self): + lrs = [ + self._get_lr(group, self.state[group["params"][0]]) + for group in self.param_groups + if group["params"][0].grad is not None + ] + if len(lrs) == 0: + lrs = self.base_lrs # if called before stepping + return lrs + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self.step_hook() + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + if grad.dtype != torch.float32: + grad = grad.to(torch.float32) + if grad.is_sparse: + raise RuntimeError( + "Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options( + group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros( + grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to( + grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to( + grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype != torch.float32: + p_data_fp32 = p_data_fp32.clone().float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + eps = group["eps"] + if isinstance(eps, tuple) or isinstance(eps, list): + eps = eps[0] + update = (grad**2) + eps + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad( + exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_( + update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype != torch.float32: + # apply stochastic rounding + copy_stochastic(p, p_data_fp32) + + return loss diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py new file mode 100644 index 00000000..f895244c --- /dev/null +++ b/toolkit/optimizers/optimizer_utils.py @@ -0,0 +1,145 @@ +import torch +from torch import Tensor +from typing import Optional + + +def get_format_params(dtype: torch.dtype) -> tuple[int, int]: + """ + Returns (mantissa_bits, total_bits) for each format. + mantissa_bits excludes the implicit leading 1. 
+ """ + if dtype == torch.float32: + return 23, 32 + elif dtype == torch.bfloat16: + return 7, 16 + elif dtype == torch.float16: + return 10, 16 + elif dtype == torch.float8_e4m3fn: + return 3, 8 + elif dtype == torch.float8_e5m2: + return 2, 8 + elif dtype == torch.int8: + return 0, 8 # Int8 doesn't have mantissa bits + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def copy_stochastic( + target: torch.Tensor, + source: torch.Tensor, + eps: Optional[float] = None +) -> None: + """ + Performs stochastic rounding from source tensor to target tensor. + + Args: + target: Destination tensor (determines the target format) + source: Source tensor (typically float32) + eps: Optional minimum value for stochastic rounding (for numerical stability) + """ + with torch.no_grad(): + # If target is float32, just copy directly + if target.dtype == torch.float32: + target.copy_(source) + return + + # Special handling for int8 + if target.dtype == torch.int8: + # Scale the source values to utilize the full int8 range + scaled = source * 127.0 # Scale to [-127, 127] + + # Add random noise for stochastic rounding + noise = torch.rand_like(scaled) - 0.5 + rounded = torch.round(scaled + noise) + + # Clamp to int8 range + clamped = torch.clamp(rounded, -127, 127) + target.copy_(clamped.to(torch.int8)) + return + + mantissa_bits, _ = get_format_params(target.dtype) + + # Convert source to int32 view + source_int = source.view(dtype=torch.int32) + + # Calculate number of bits to round + bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits + + # Create random integers for stochastic rounding + rand = torch.randint_like( + source, + dtype=torch.int32, + low=0, + high=(1 << bits_to_round), + ) + + # Add random values to the bits that will be rounded off + result = source_int.clone() + result.add_(rand) + + # Mask to keep only the bits we want + # Create mask with 1s in positions we want to keep + mask = (-1) << bits_to_round + result.bitwise_and_(mask) + + # Handle minimum value threshold if specified + if eps is not None: + eps_int = torch.tensor( + eps, dtype=torch.float32).view(dtype=torch.int32) + zero_mask = (result.abs() < eps_int) + result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int + + # Convert back to float32 view + result_float = result.view(dtype=torch.float32) + + # Special handling for float8 formats + if target.dtype == torch.float8_e4m3fn: + result_float.clamp_(-448.0, 448.0) + elif target.dtype == torch.float8_e5m2: + result_float.clamp_(-57344.0, 57344.0) + + target.copy_(result_float) + del result, rand, source_int + + +class Auto8bitTensor: + def __init__(self, data: Tensor, *args, **kwargs): + + abs_max = data.abs().max().item() + scale = abs_max / 127.0 if abs_max > 0 else 1.0 + + self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8) + self.scale = scale + self.orig_dtype = data.dtype + + def dequantize(self) -> Tensor: + return self.quantized.to(dtype=torch.float32) * self.scale + + def to(self, *args, **kwargs): + # Handle the dtype argument whether it's positional or keyword + dtype = None + if args and isinstance(args[0], torch.dtype): + dtype = args[0] + args = args[1:] + elif 'dtype' in kwargs: + dtype = kwargs['dtype'] + del kwargs['dtype'] + + if dtype is not None: + # First dequantize then convert to requested dtype + return self.dequantize().to(dtype=dtype, *args, **kwargs) + + # If no dtype specified, just pass through to parent + return self.dequantize().to(*args, **kwargs) + + +def stochastic_grad_accummulation(param): + if 
hasattr(param, "_accum_grad"): + grad_fp32 = param._accum_grad.clone().to(torch.float32) + grad_fp32.add_(param.grad.to(torch.float32)) + copy_stochastic(param._accum_grad, grad_fp32) + del grad_fp32 + del param.grad + else: + param._accum_grad = param.grad.clone() + del param.grad diff --git a/toolkit/optimizers/prodigy_8bit.py b/toolkit/optimizers/prodigy_8bit.py index 55f39423..ee7f0914 100644 --- a/toolkit/optimizers/prodigy_8bit.py +++ b/toolkit/optimizers/prodigy_8bit.py @@ -1,150 +1,8 @@ import math import torch -from torch import Tensor import torch.distributed as dist from torch.optim import Optimizer -from typing import Optional - - -def get_format_params(dtype: torch.dtype) -> tuple[int, int]: - """ - Returns (mantissa_bits, total_bits) for each format. - mantissa_bits excludes the implicit leading 1. - """ - if dtype == torch.float32: - return 23, 32 - elif dtype == torch.bfloat16: - return 7, 16 - elif dtype == torch.float16: - return 10, 16 - elif dtype == torch.float8_e4m3fn: - return 3, 8 - elif dtype == torch.float8_e5m2: - return 2, 8 - elif dtype == torch.int8: - return 0, 8 # Int8 doesn't have mantissa bits - else: - raise ValueError(f"Unsupported dtype: {dtype}") - - -def copy_stochastic( - target: torch.Tensor, - source: torch.Tensor, - eps: Optional[float] = None -) -> None: - """ - Performs stochastic rounding from source tensor to target tensor. - - Args: - target: Destination tensor (determines the target format) - source: Source tensor (typically float32) - eps: Optional minimum value for stochastic rounding (for numerical stability) - """ - with torch.no_grad(): - # If target is float32, just copy directly - if target.dtype == torch.float32: - target.copy_(source) - return - - # Special handling for int8 - if target.dtype == torch.int8: - # Scale the source values to utilize the full int8 range - scaled = source * 127.0 # Scale to [-127, 127] - - # Add random noise for stochastic rounding - noise = torch.rand_like(scaled) - 0.5 - rounded = torch.round(scaled + noise) - - # Clamp to int8 range - clamped = torch.clamp(rounded, -127, 127) - target.copy_(clamped.to(torch.int8)) - return - - mantissa_bits, _ = get_format_params(target.dtype) - - # Convert source to int32 view - source_int = source.view(dtype=torch.int32) - - # Calculate number of bits to round - bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits - - # Create random integers for stochastic rounding - rand = torch.randint_like( - source, - dtype=torch.int32, - low=0, - high=(1 << bits_to_round), - ) - - # Add random values to the bits that will be rounded off - result = source_int.clone() - result.add_(rand) - - # Mask to keep only the bits we want - # Create mask with 1s in positions we want to keep - mask = (-1) << bits_to_round - result.bitwise_and_(mask) - - # Handle minimum value threshold if specified - if eps is not None: - eps_int = torch.tensor( - eps, dtype=torch.float32).view(dtype=torch.int32) - zero_mask = (result.abs() < eps_int) - result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int - - # Convert back to float32 view - result_float = result.view(dtype=torch.float32) - - # Special handling for float8 formats - if target.dtype == torch.float8_e4m3fn: - result_float.clamp_(-448.0, 448.0) - elif target.dtype == torch.float8_e5m2: - result_float.clamp_(-57344.0, 57344.0) - - target.copy_(result_float) - - -class Auto8bitTensor: - def __init__(self, data: Tensor, *args, **kwargs): - - abs_max = data.abs().max().item() - scale = abs_max / 127.0 if abs_max > 0 else 1.0 - 
- self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8) - self.scale = scale - self.orig_dtype = data.dtype - - def dequantize(self) -> Tensor: - return self.quantized.to(dtype=torch.float32) * self.scale - - def to(self, *args, **kwargs): - # Handle the dtype argument whether it's positional or keyword - dtype = None - if args and isinstance(args[0], torch.dtype): - dtype = args[0] - args = args[1:] - elif 'dtype' in kwargs: - dtype = kwargs['dtype'] - del kwargs['dtype'] - - if dtype is not None: - # First dequantize then convert to requested dtype - return self.dequantize().to(dtype=dtype, *args, **kwargs) - - # If no dtype specified, just pass through to parent - return self.dequantize().to(*args, **kwargs) - - -def stochastic_grad_accummulation(param): - if hasattr(param, "_accum_grad"): - grad_fp32 = param._accum_grad.clone().to(torch.float32) - grad_fp32.add_(param.grad.to(torch.float32)) - copy_stochastic(param._accum_grad, grad_fp32) - del grad_fp32 - del param.grad - else: - param._accum_grad = param.grad.clone() - del param.grad +from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation class Prodigy8bit(Optimizer): @@ -222,7 +80,7 @@ class Prodigy8bit(Optimizer): fsdp_in_use=fsdp_in_use) self.d0 = d0 super(Prodigy8bit, self).__init__(params, defaults) - + self.is_stochastic_rounding_accumulation = False # setup stochastic grad accum hooks diff --git a/toolkit/util/adafactor_stochastic_rounding.py b/toolkit/util/adafactor_stochastic_rounding.py deleted file mode 100644 index c9930322..00000000 --- a/toolkit/util/adafactor_stochastic_rounding.py +++ /dev/null @@ -1,120 +0,0 @@ -# ref https://github.com/Nerogar/OneTrainer/compare/master...stochastic_rounding -import math -import torch -from torch import Tensor - - -def copy_stochastic_(target: Tensor, source: Tensor): - # create a random 16 bit integer - result = torch.randint_like( - source, - dtype=torch.int32, - low=0, - high=(1 << 16), - ) - - # add the random number to the lower 16 bit of the mantissa - result.add_(source.view(dtype=torch.int32)) - - # mask off the lower 16 bit of the mantissa - result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 - - # copy the higher 16 bit into the target tensor - target.copy_(result.view(dtype=torch.float32)) - - -@torch.no_grad() -def step_adafactor(self, closure=None): - """ - Performs a single optimization step - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() - if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") - - state = self.state[p] - grad_shape = grad.shape - - factored, use_first_moment = self._get_options(group, grad_shape) - # State Initialization - if len(state) == 0: - state["step"] = 0 - - if use_first_moment: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(grad) - if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) - else: - state["exp_avg_sq"] = torch.zeros_like(grad) - - state["RMS"] = 0 - else: - if use_first_moment: - state["exp_avg"] = state["exp_avg"].to(grad) - if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) - else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - - p_data_fp32 = p - if p.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - - state["step"] += 1 - state["RMS"] = self._rms(p_data_fp32) - lr = self._get_lr(group, state) - - beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"] - update = (grad ** 2) + eps - if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] - - exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) - exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) - - # Approximation of exponential moving average of square of gradient - update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) - update.mul_(grad) - else: - exp_avg_sq = state["exp_avg_sq"] - - exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) - update = exp_avg_sq.rsqrt().mul_(grad) - - update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) - update.mul_(lr) - - if use_first_moment: - exp_avg = state["exp_avg"] - exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) - update = exp_avg - - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - - p_data_fp32.add_(-update) - - if p.dtype == torch.bfloat16: - copy_stochastic_(p, p_data_fp32) - elif p.dtype == torch.float16: - p.copy_(p_data_fp32) - - return loss
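
Usage sketches for review (not part of the patch). First, the LR-reporting change in `BaseSDTrainProcess.py`: the progress bar now prefers the optimizer's own `get_learning_rates()` when the method exists, which the new `Adafactor` provides. A minimal sketch with a placeholder model and made-up hyperparameters:

```python
import torch
from toolkit.optimizers.adafactor import Adafactor

model = torch.nn.Linear(8, 8)  # placeholder module for illustration

# Manual learning rate: the constructor requires relative_step=False in that case,
# and warmup_init=False since warmup_init needs relative_step.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)

# Same branch order as the patched progress-bar code: ask the optimizer first.
# The dadaptation/prodigy and plain param_groups fallbacks are simplified here.
if hasattr(optimizer, 'get_learning_rates'):
    learning_rate = optimizer.get_learning_rates()[0]
else:
    learning_rate = optimizer.param_groups[0]['lr']
print(learning_rate)
```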
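
A quick way to see what `copy_stochastic` in `optimizer_utils.py` does: rounding the same float32 value into bfloat16 many times should average out close to the original, whereas a plain cast always lands on the same neighbour. Standalone sketch, not part of the patch:

```python
import torch
from toolkit.optimizers.optimizer_utils import copy_stochastic

value = 1.0 + 1e-3                      # falls between two bfloat16 neighbours
source = torch.full((100_000,), value)  # float32
target = torch.empty_like(source, dtype=torch.bfloat16)

copy_stochastic(target, source)

# A deterministic cast picks the nearest neighbour every time; stochastic rounding
# picks the neighbour above or below with probability proportional to the distance,
# so the mean of the rounded copies stays near the true value.
print(source.to(torch.bfloat16).float().mean().item())  # always 1.0 here
print(target.float().mean().item())                     # ~1.001 in expectation
```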
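
The `stochastic_grad_accummulation` hook is registered by the optimizers themselves (see the `__init__` of `Adafactor` and `Prodigy8bit`, which attach it only to non-float32 params), but its effect is easier to see wired up by hand. Rough sketch assuming PyTorch 2.1+ for `register_post_accumulate_grad_hook`:

```python
import torch
from toolkit.optimizers.optimizer_utils import stochastic_grad_accummulation

param = torch.nn.Parameter(torch.randn(16, dtype=torch.bfloat16))
param.register_post_accumulate_grad_hook(stochastic_grad_accummulation)

# Two "micro-batches": after each backward the hook folds the fresh bfloat16
# grad into param._accum_grad using float32 math, writes it back with
# stochastic rounding, and clears param.grad for the next pass.
for _ in range(2):
    loss = (param.float() ** 2).sum()
    loss.backward()

print(param._accum_grad.dtype)  # torch.bfloat16
print(param.grad)               # None; step_hook() restores it before stepping
```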
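
`Auto8bitTensor`, now shared from `optimizer_utils.py` instead of living inside `prodigy_8bit.py`, stores a tensor as an int8 payload plus a single per-tensor scale. A small round-trip sketch with arbitrary values:

```python
import torch
from toolkit.optimizers.optimizer_utils import Auto8bitTensor

state = torch.randn(4, 4)
packed = Auto8bitTensor(state)        # int8 payload + float scale
restored = packed.to(torch.float32)   # dequantize, then cast

# The quantisation step is scale = abs_max / 127, so the round-trip error is
# bounded by half a step.
print((state - restored).abs().max().item() <= packed.scale / 2 + 1e-7)
```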
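
The factored branch in `Adafactor.step()` keeps only a row vector and a column vector of second-moment statistics instead of the full matrix; `_approx_sq_grad` rebuilds an approximate `1/sqrt(v)` from them. A shape-level sketch that calls the static helper directly (purely illustrative):

```python
import torch
from toolkit.optimizers.adafactor import Adafactor

grad = torch.randn(4, 6)
sq = grad ** 2

row_stat = sq.mean(dim=-1)   # shape (4,)  -> exp_avg_sq_row in the state
col_stat = sq.mean(dim=-2)   # shape (6,)  -> exp_avg_sq_col in the state

# Outer-product style reconstruction of the inverse square root of the second
# moment; memory cost is O(rows + cols) instead of O(rows * cols).
inv_sqrt_v = Adafactor._approx_sq_grad(row_stat, col_stat)
print(inv_sqrt_v.shape)      # torch.Size([4, 6])
```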