diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index f4c10cf3..1e1c36de 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -566,7 +566,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         try:
             filename = f'optimizer.pt'
             file_path = os.path.join(self.save_root, filename)
-            torch.save(self.optimizer.state_dict(), file_path)
+            state_dict = self.optimizer.state_dict()
+            torch.save(state_dict, file_path)
         except Exception as e:
             print(e)
             print("Could not save optimizer")
@@ -1786,7 +1787,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # torch.cuda.empty_cache()
                 # if optimizer has get_lrs method, then use it
-                if hasattr(optimizer, 'get_learning_rates'):
+                if hasattr(optimizer, 'get_avg_learning_rate'):
+                    learning_rate = optimizer.get_avg_learning_rate()
+                elif hasattr(optimizer, 'get_learning_rates'):
                     learning_rate = optimizer.get_learning_rates()[0]
                 elif self.train_config.optimizer.lower().startswith('dadaptation') or \
                         self.train_config.optimizer.lower().startswith('prodigy'):
diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py
index c7eac6f5..dfbebc90 100644
--- a/toolkit/optimizers/automagic.py
+++ b/toolkit/optimizers/automagic.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 import math
 from typing import List
 import torch
@@ -95,15 +96,16 @@ class Automagic(torch.optim.Optimizer):
 
     @staticmethod
     def _get_lr(param_group, param_state):
-        lr = param_group["avg_lr"]
-        param_scale = 1.0
-        return param_scale * lr
+        if 'avg_lr' in param_state:
+            lr = param_state["avg_lr"]
+        else:
+            lr = param_state["lr"]
+        return lr
 
     def _get_group_lr(self, group):
         group_lrs = []
         for p in group["params"]:
-            if p.grad is not None:
-                group_lrs.append(self._get_lr(group, self.state[p]))
+            group_lrs.append(self._get_lr(group, self.state[p]))
         # return avg
         if len(group_lrs) == 0:
             return self.lr
@@ -179,26 +181,7 @@ class Automagic(torch.optim.Optimizer):
         factored = len(grad_shape) >= 2
         # State Initialization
         if len(state) == 0:
-            state["step"] = 0
-
-            # store the lr mask
-            state['lr_mask'] = Auto8bitTensor(torch.ones(
-                p.shape).to(p.device, dtype=torch.float32) * self.lr
-            )
-            state['avg_lr'] = torch.mean(
-                state['lr_mask'].to(torch.float32))
-            state['last_polarity'] = torch.zeros(
-                p.shape, dtype=torch.bool, device=p.device)
-
-            if factored:
-                state["exp_avg_sq_row"] = torch.zeros(
-                    grad_shape[:-1]).to(grad)
-                state["exp_avg_sq_col"] = torch.zeros(
-                    grad_shape[:-2] + grad_shape[-1:]).to(grad)
-            else:
-                state["exp_avg_sq"] = torch.zeros_like(grad)
-
-            state["RMS"] = 0
+            self.initialize_state(p)
         else:
             if factored:
                 state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
@@ -294,3 +277,57 @@ class Automagic(torch.optim.Optimizer):
                 copy_stochastic(p, p_data_fp32)
 
         return loss
+
+    def initialize_state(self, p):
+        state = self.state[p]
+        state["step"] = 0
+
+        # store the lr mask
+        if 'lr_mask' not in state:
+            state['lr_mask'] = Auto8bitTensor(torch.ones(
+                p.shape).to(p.device, dtype=torch.float32) * self.lr
+            )
+        state['avg_lr'] = torch.mean(
+            state['lr_mask'].to(torch.float32))
+        if 'last_polarity' not in state:
+            state['last_polarity'] = torch.zeros(
+                p.shape, dtype=torch.bool, device=p.device)
+
+        factored = len(p.shape) >= 2
+        if factored:
+            state["exp_avg_sq_row"] = torch.zeros(
+                p.shape[:-1]).to(p)
+            state["exp_avg_sq_col"] = torch.zeros(
+                p.shape[:-2] + p.shape[-1:]).to(p)
+        else:
+            state["exp_avg_sq"] = torch.zeros_like(p)
+
+        state["RMS"] = 0
+
+    # override state_dict to save the lr_mask
+    def state_dict(self, *args, **kwargs):
+        orig_state_dict = super().state_dict(*args, **kwargs)
+        # replace each lr_mask Auto8bitTensor with its serializable state dict
+        new_save_state = {}
+        for p, state in orig_state_dict['state'].items():
+            save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
+            save_state['lr_mask'] = state['lr_mask'].state_dict()
+            new_save_state[p] = save_state
+
+        orig_state_dict['state'] = new_save_state
+
+        return orig_state_dict
+
+    def load_state_dict(self, state_dict, strict=True):
+        # load the lr_mask from the state_dict
+        idx = 0
+        for group in self.param_groups:
+            for p in group['params']:
+                self.initialize_state(p)
+                state = self.state[p]
+                m = state_dict['state'][idx]['lr_mask']
+                sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
+                state['lr_mask'] = Auto8bitTensor(sd_mask)
+                del state_dict['state'][idx]['lr_mask']
+                idx += 1
+        super().load_state_dict(state_dict, strict)
diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py
index a559d0d7..67991f21 100644
--- a/toolkit/optimizers/optimizer_utils.py
+++ b/toolkit/optimizers/optimizer_utils.py
@@ -241,7 +241,7 @@ class Auto8bitTensor:
         self.orig_dtype = state_dict['orig_dtype']
 
     def __str__(self):
-        return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"
+        return f"Auto8bitTensor({self.dequantize()})"
 
 
 def stochastic_grad_accummulation(param):
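
Taken together, the changes let Automagic's per-parameter `lr_mask` (an `Auto8bitTensor`) survive a checkpoint round trip: `state_dict()` swaps each mask for its plain-tensor state dict so `torch.save` can serialize it, and `load_state_dict()` rebuilds the masks on freshly initialized state. A minimal sketch of that round trip; the `Automagic([params], lr=...)` constructor signature is an assumption, since it is not shown in this diff:

```python
import torch
from toolkit.optimizers.automagic import Automagic

# Toy parameter; shape and values are illustrative only.
param = torch.nn.Parameter(torch.randn(4, 4))
opt = Automagic([param], lr=1e-4)  # assumed constructor signature

# Run one step so per-parameter state (including lr_mask) exists.
param.grad = torch.randn_like(param)
opt.step()

# state_dict() replaces each Auto8bitTensor lr_mask with a plain dict
# ('quantized', 'scale', 'orig_dtype'), so torch.save can serialize it.
torch.save(opt.state_dict(), 'optimizer.pt')

# load_state_dict() re-initializes state, then rebuilds each lr_mask
# from quantized.to(orig_dtype) * scale.
opt_resumed = Automagic([param], lr=1e-4)
opt_resumed.load_state_dict(torch.load('optimizer.pt'))
```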
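The load path's `m['quantized'].to(m['orig_dtype']) * m['scale']` implies a simple scale-based 8-bit scheme. A standalone sketch of that arithmetic with plain tensors, assuming an absmax-style quantizer (the real one lives in `Auto8bitTensor` in optimizer_utils.py and is not shown here):

```python
import torch

# Hypothetical absmax-style quantization consistent with the load path.
mask = torch.rand(4, 4) * 1e-4  # stand-in for an lr_mask
scale = mask.abs().max() / 127.0
quantized = (mask / scale).round().clamp(-127, 127).to(torch.int8)

# Dequantization exactly as load_state_dict performs it:
restored = quantized.to(torch.float32) * scale
assert torch.allclose(mask, restored, atol=scale.item())
```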
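The new `get_avg_learning_rate` branch in BaseSDTrainProcess implies Automagic exposes such a method, though it is not part of this diff. A hypothetical shape for it, consistent with the per-parameter `avg_lr` state:

```python
# Hypothetical; the real method is presumably defined in automagic.py
# but does not appear in this diff.
def get_avg_learning_rate(self):
    lrs = [
        self.state[p]['avg_lr']
        for group in self.param_groups
        for p in group['params']
        if 'avg_lr' in self.state[p]
    ]
    if not lrs:
        return self.lr  # fall back to the base lr before any step
    return sum(lrs) / len(lrs)
```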