From 2b4c525489c6fd42a8e99724d9c528b3dcbecb92 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 28 Apr 2025 08:01:10 -0600
Subject: [PATCH] Reworked automagic optimizer and did more testing. Starting to really like it. Working well.

---
 jobs/process/BaseSDTrainProcess.py |   8 +-
 toolkit/models/base_model.py       |   8 +-
 toolkit/network_mixins.py          |   2 +-
 toolkit/optimizers/automagic.py    | 184 +++++++++++++++++++++--------
 toolkit/stable_diffusion_model.py  |   8 +-
 5 files changed, 149 insertions(+), 61 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 83ed8242..1ccb0c3d 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -62,7 +62,7 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
     DecoratorConfig
 from toolkit.logging_aitk import create_logger
 from diffusers import FluxTransformer2DModel
-from toolkit.accelerator import get_accelerator
+from toolkit.accelerator import get_accelerator, unwrap_model
 from toolkit.print import print_acc
 from accelerate import Accelerator
 import transformers
@@ -629,7 +629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             try:
                 filename = f'optimizer.pt'
                 file_path = os.path.join(self.save_root, filename)
-                state_dict = self.optimizer.state_dict()
+                state_dict = unwrap_model(self.optimizer).state_dict()
                 torch.save(state_dict, file_path)
                 print_acc(f"Saved optimizer to {file_path}")
             except Exception as e:
@@ -1457,7 +1457,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.load_training_state_from_metadata(previous_refiner_save)
 
         self.sd = ModelClass(
-            device=self.device,
+            # todo handle single gpu and multi gpu here
+            # device=self.device,
+            device=self.accelerator.device,
             model_config=model_config_to_load,
             dtype=self.train_config.dtype,
             custom_pipeline=self.custom_pipeline,
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 6982e506..27a94503 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -113,15 +113,15 @@ class BaseModel:
     ):
         self.accelerator = get_accelerator()
         self.custom_pipeline = custom_pipeline
-        self.device = str(self.accelerator.device)
+        self.device = device
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
-        self.device_torch = self.accelerator.device
+        self.device_torch = torch.device(device)
 
-        self.vae_device_torch = self.accelerator.device
+        self.vae_device_torch = torch.device(device)
         self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
 
-        self.te_device_torch = self.accelerator.device
+        self.te_device_torch = torch.device(device)
         self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
 
         self.model_config = model_config
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 826d7f09..796e4042 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -182,7 +182,7 @@ class ToolkitModuleMixin:
                 lx = self.lora_down(x)
             except RuntimeError as e:
                 print(f"Error in {self.__class__.__name__} lora_down")
-                print(e)
+                raise e
 
             if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
                 lx = self.dropout(lx)
diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py
index 0aa5d51f..f5a88eff 100644
--- a/toolkit/optimizers/automagic.py
+++ b/toolkit/optimizers/automagic.py
@@ -1,5 +1,3 @@
-from collections import OrderedDict
-import math
 from typing import List
 import torch
 from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation
@@ -11,29 +9,30 @@ class Automagic(torch.optim.Optimizer):
     def __init__(
         self,
         params,
-        lr=None,
+        lr=1e-6,  # lr is start lr
         min_lr=1e-7,
         max_lr=1e-3,
-        lr_pump_scale=1.1,
-        lr_dump_scale=0.85,
+        lr_bump=1e-6,  # amount to bump the lr when adjusting
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
-        decay_rate=-0.8,
+        beta2=0.999,
        weight_decay=0.0,
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
     ):
         self.lr = lr
+        if self.lr > 1e-3:
+            print(f"Warning! Start lr is very high: {self.lr}. Forcing to 1e-6. this does not work like prodigy")
+            self.lr = 1e-6
         self.min_lr = min_lr
         self.max_lr = max_lr
-        self.lr_pump_scale = lr_pump_scale
-        self.lr_dump_scale = lr_dump_scale
+        self.lr_bump = lr_bump
 
         defaults = {
             "lr": lr,
             "eps": eps,
             "clip_threshold": clip_threshold,
-            "decay_rate": decay_rate,
+            "beta2": beta2,
             "weight_decay": weight_decay,
         }
         super().__init__(params, defaults)
@@ -119,8 +118,6 @@ class Automagic(torch.optim.Optimizer):
 
     @staticmethod
     def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
-        # copy from fairseq's adafactor implementation:
-        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
         r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-
                     1, keepdim=True)).rsqrt_().unsqueeze(-1)
         c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
@@ -136,7 +133,7 @@ class Automagic(torch.optim.Optimizer):
                 param.grad = param._accum_grad
                 del param._accum_grad
 
-    # adafactor manages its own lr
+    # automagic manages its own lr
 
     def get_learning_rates(self):
         lrs = [
@@ -185,13 +182,20 @@ class Automagic(torch.optim.Optimizer):
                 if len(state) == 0:
                     self.initialize_state(p)
                 else:
+                    # Check if exp_avg_sq_row and exp_avg_sq_col exist for factored case
                     if factored:
-                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
-                            grad)
-                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
-                            grad)
+                        if "exp_avg_sq_row" not in state or "exp_avg_sq_col" not in state:
+                            state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(grad)
+                            state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad)
+                        else:
+                            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    # Check if exp_avg_sq exists for non-factored case
                     else:
-                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+                        if "exp_avg_sq" not in state:
+                            state["exp_avg_sq"] = torch.zeros_like(grad)
+                        else:
+                            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
 
                 p_data_fp32 = p
 
@@ -200,11 +204,14 @@ class Automagic(torch.optim.Optimizer):
                 if p.dtype != torch.float32:
                     p_data_fp32 = p_data_fp32.clone().float()
 
+                # Initialize step if it doesn't exist
+                if "step" not in state:
+                    state["step"] = 0
                 state["step"] += 1
                 state["RMS"] = self._rms(p_data_fp32)
 
-                # lr = self._get_lr(group, state)
-                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                # Use fixed beta2 from group instead of decay_rate calculation
+                beta2 = group["beta2"]
                 eps = group["eps"]
                 if isinstance(eps, tuple) or isinstance(eps, list):
                     eps = eps[0]
@@ -213,10 +220,10 @@ class Automagic(torch.optim.Optimizer):
                     exp_avg_sq_row = state["exp_avg_sq_row"]
                     exp_avg_sq_col = state["exp_avg_sq_col"]
 
-                    exp_avg_sq_row.mul_(beta2t).add_(
-                        update.mean(dim=-1), alpha=(1.0 - beta2t))
-                    exp_avg_sq_col.mul_(beta2t).add_(
-                        update.mean(dim=-2), alpha=(1.0 - beta2t))
+                    exp_avg_sq_row.mul_(beta2).add_(
+                        update.mean(dim=-1), alpha=(1.0 - beta2))
+                    exp_avg_sq_col.mul_(beta2).add_(
+                        update.mean(dim=-2), alpha=(1.0 - beta2))
 
                     # Approximation of exponential moving average of square of gradient
                     update = self._approx_sq_grad(
@@ -225,20 +232,16 @@ class Automagic(torch.optim.Optimizer):
                 else:
                     exp_avg_sq = state["exp_avg_sq"]
 
-                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    exp_avg_sq.mul_(beta2).add_(update, alpha=(1.0 - beta2))
                     update = exp_avg_sq.rsqrt().mul_(grad)
 
                 update.div_(
                     (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
 
-                # calculate new lr mask. if the updated param is going in same direction, increase lr, else decrease
-                # update the lr mask. self.lr_momentum is < 1.0. If a paramiter is positive and increasing (or negative and decreasing), increase lr,
-                # for that single paramiter. If a paramiter is negative and increasing or positive and decreasing, decrease lr for that single paramiter.
-                # to decrease lr, multiple by self.lr_momentum, to increase lr, divide by self.lr_momentum.
-
-                # not doing it this way anymore
-                # update.mul_(lr)
-
+                # Ensure state is properly initialized
+                if 'last_polarity' not in state or 'lr_mask' not in state:
+                    self.initialize_state(p)
+
                 # Get signs of current last update and updates
                 last_polarity = state['last_polarity']
                 current_polarity = (update > 0).to(torch.bool)
@@ -251,8 +254,8 @@ class Automagic(torch.optim.Optimizer):
                 # Update learning rate mask based on sign agreement
                 new_lr = torch.where(
                     sign_agreement > 0,
-                    lr_mask * self.lr_pump_scale,  # Increase lr
-                    lr_mask * self.lr_dump_scale  # Decrease lr
+                    lr_mask + self.lr_bump,  # Increase lr
+                    lr_mask - self.lr_bump  # Decrease lr
                 )
 
                 # Clip learning rates to bounds
@@ -269,8 +272,11 @@ class Automagic(torch.optim.Optimizer):
                 state['avg_lr'] = torch.mean(new_lr)
 
                 if group["weight_decay"] != 0:
-                    p_data_fp32.add_(
-                        p_data_fp32, alpha=(-group["weight_decay"] * new_lr))
+                    # Apply weight decay with per-parameter learning rates
+                    # Instead of using add_ with a tensor alpha (which isn't supported),
+                    # we'll use element-wise multiplication to apply the weight decay
+                    weight_decay_update = p_data_fp32 * (-group["weight_decay"]) * new_lr
+                    p_data_fp32.add_(weight_decay_update)
 
                 p_data_fp32.add_(-update)
 
@@ -313,7 +319,11 @@ class Automagic(torch.optim.Optimizer):
         new_sace_state = {}
         for p, state in orig_state_dict['state'].items():
             save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
-            save_state['lr_mask'] = state['lr_mask'].state_dict()
+
+            # Check if lr_mask exists in the state before trying to access it
+            if 'lr_mask' in state:
+                save_state['lr_mask'] = state['lr_mask'].state_dict()
+
             new_sace_state[p] = save_state
 
         orig_state_dict['state'] = new_sace_state
@@ -321,17 +331,93 @@ class Automagic(torch.optim.Optimizer):
         return orig_state_dict
 
     def load_state_dict(self, state_dict, strict=True):
-        # load the lr_mask from the state_dict
-        # dont load state dict for now. Has a bug. Need to fix it.
-        return
-        idx = 0
+        # Validate that the state_dict is from an Automagic optimizer
+        is_valid_automagic_state = False
+
+        # Check if state_dict has the expected structure
+        if 'state' in state_dict and isinstance(state_dict['state'], dict):
+            # Check if at least one state entry has an lr_mask, which is specific to Automagic
+            for param_id, param_state in state_dict['state'].items():
+                if isinstance(param_state, dict) and 'lr_mask' in param_state:
+                    is_valid_automagic_state = True
+                    break
+
+        if not is_valid_automagic_state:
+            return
+
+        # First, call the parent class's load_state_dict to load the basic optimizer state
+        # We'll handle the lr_mask separately
+        state_dict_copy = {
+            'state': {},
+            'param_groups': state_dict['param_groups']
+        }
+
+        # Copy all state entries except lr_mask
+        for param_id, param_state in state_dict['state'].items():
+            state_dict_copy['state'][param_id] = {
+                k: v for k, v in param_state.items() if k != 'lr_mask'
+            }
+
+        # Call parent class load_state_dict with the modified state dict
+        super().load_state_dict(state_dict_copy)
+
+        # Now handle the lr_mask separately
+        # We need to map the saved parameters to the current parameters
+        # This is tricky because the parameter IDs might be different
+
+        # Get all current parameters that require gradients
+        current_params = []
         for group in self.param_groups:
             for p in group['params']:
-                self.initialize_state(p)
-                state = self.state[p]
-                m = state_dict['state'][idx]['lr_mask']
-                sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
-                state['lr_mask'] = Auto8bitTensor(sd_mask)
-                del state_dict['state'][idx]['lr_mask']
-                idx += 1
-        super().load_state_dict(state_dict)
+                if p.requires_grad:
+                    current_params.append(p)
+
+        # If the number of parameters doesn't match, we can't reliably map them
+        if len(current_params) != len(state_dict['param_groups'][0]['params']):
+            print(f"WARNING: Number of parameters doesn't match between saved state ({len(state_dict['param_groups'][0]['params'])}) "
+                  f"and current model ({len(current_params)}). Learning rate masks may not be correctly loaded.")
+
+        # Map parameters by their position in the param_groups
+        # This assumes the order of parameters is preserved between saving and loading
+        saved_param_ids = list(state_dict['state'].keys())
+
+        for i, current_param in enumerate(current_params):
+            if i >= len(saved_param_ids):
+                break
+
+            saved_param_id = saved_param_ids[i]
+            saved_state = state_dict['state'][saved_param_id]
+
+            # Skip if this saved state doesn't have an lr_mask
+            if 'lr_mask' not in saved_state:
+                continue
+
+            # Initialize the state for this parameter if it doesn't exist
+            if current_param not in self.state:
+                self.initialize_state(current_param)
+
+            # Get the current state for this parameter
+            current_state = self.state[current_param]
+
+            # Load the lr_mask from the saved state
+            saved_lr_mask = saved_state['lr_mask']
+
+            # Reconstruct the Auto8bitTensor from its state dict
+            try:
+                # Make sure the shapes match
+                if 'quantized' in saved_lr_mask and saved_lr_mask['quantized'].shape == current_param.shape:
+                    current_state['lr_mask'] = Auto8bitTensor(saved_lr_mask)
+                else:
+                    print(f"WARNING: Shape mismatch for parameter {i}. "
+                          f"Expected {current_param.shape}, got {saved_lr_mask['quantized'].shape if 'quantized' in saved_lr_mask else 'unknown'}. "
+                          f"Initializing new lr_mask.")
+                    # Initialize a new lr_mask
+                    current_state['lr_mask'] = Auto8bitTensor(torch.ones(
+                        current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
+                    )
+            except Exception as e:
+                print(f"ERROR: Failed to load lr_mask for parameter {i}: {e}")
+                # Initialize a new lr_mask
+                current_state['lr_mask'] = Auto8bitTensor(torch.ones(
+                    current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
+                )
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 66c8c19c..3780792f 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -142,15 +142,15 @@ class StableDiffusion:
     ):
         self.accelerator = get_accelerator()
         self.custom_pipeline = custom_pipeline
-        self.device = str(self.accelerator.device)
+        self.device = device
+        self.device_torch = torch.device(device)
         self.dtype = dtype
         self.torch_dtype = get_torch_dtype(dtype)
-        self.device_torch = self.accelerator.device
 
-        self.vae_device_torch = self.accelerator.device
+        self.vae_device_torch = torch.device(device)
         self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
 
-        self.te_device_torch = self.accelerator.device
+        self.te_device_torch = torch.device(device)
         self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
 
         self.model_config = model_config