Reworked the Automagic optimizer and did more testing. Starting to really like it; it's working well.

Jaret Burkett
2025-04-28 08:01:10 -06:00
parent 88b3fbae37
commit 2b4c525489
5 changed files with 149 additions and 61 deletions

View File

@@ -62,7 +62,7 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
DecoratorConfig
from toolkit.logging_aitk import create_logger
from diffusers import FluxTransformer2DModel
from toolkit.accelerator import get_accelerator
from toolkit.accelerator import get_accelerator, unwrap_model
from toolkit.print import print_acc
from accelerate import Accelerator
import transformers
@@ -629,7 +629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
state_dict = self.optimizer.state_dict()
state_dict = unwrap_model(self.optimizer).state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
except Exception as e:
@@ -1457,7 +1457,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = ModelClass(
device=self.device,
# todo handle single gpu and multi gpu here
# device=self.device,
device=self.accelerator.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
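
For context, the two changes in this file are: the optimizer is unwrapped before its state_dict is saved (accelerate wraps it in prepare()), and the model is now built on accelerator.device. A minimal sketch of the save pattern under that assumption, using a toy model and a generic unwrap fallback rather than the repo's unwrap_model helper:

import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

# prepare() returns an AcceleratedOptimizer wrapper; grab the underlying
# optimizer (the repo's unwrap_model helper presumably does the equivalent)
# so the saved state_dict can be reloaded outside of accelerate as well.
inner = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
os.makedirs("checkpoints", exist_ok=True)
torch.save(inner.state_dict(), os.path.join("checkpoints", "optimizer.pt"))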

View File

@@ -113,15 +113,15 @@ class BaseModel:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = str(self.accelerator.device)
self.device = device
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = self.accelerator.device
self.device_torch = torch.device(device)
self.vae_device_torch = self.accelerator.device
self.vae_device_torch = torch.device(device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
self.te_device_torch = self.accelerator.device
self.te_device_torch = torch.device(device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
self.model_config = model_config
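
The hunk above (mirrored in StableDiffusion at the bottom of this commit) stops hard-coding every component onto accelerator.device and instead honors the device argument, which the trainer now passes in explicitly. A small sketch of the intent; the component names just echo the attributes above:

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# The trainer passes accelerator.device for the typical single-process case,
# but a caller could pass a different device per component (e.g. offload the VAE).
device = str(accelerator.device)

device_torch = torch.device(device)      # transformer / UNet
vae_device_torch = torch.device(device)  # VAE
te_device_torch = torch.device(device)   # text encoder(s)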

View File

@@ -182,7 +182,7 @@ class ToolkitModuleMixin:
lx = self.lora_down(x)
except RuntimeError as e:
print(f"Error in {self.__class__.__name__} lora_down")
print(e)
raise e
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
lx = self.dropout(lx)
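
The print(e) -> raise e change means a failing lora_down now aborts the step instead of continuing with an undefined lx. A hedged sketch of the general pattern (toy layer, illustrative names):

import torch
import torch.nn as nn

lora_down = nn.Linear(16, 4, bias=False)  # stand-in for the LoRA down projection

def run_lora_down(x: torch.Tensor) -> torch.Tensor:
    try:
        return lora_down(x)
    except RuntimeError:
        # Log some context, then re-raise so the caller sees the real failure;
        # a bare `raise` re-raises the active exception with its traceback.
        print("Error in lora_down forward, input shape:", tuple(x.shape))
        raise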

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
import math
from typing import List
import torch
from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation
@@ -11,29 +9,30 @@ class Automagic(torch.optim.Optimizer):
def __init__(
self,
params,
lr=None,
lr=1e-6, # lr is start lr
min_lr=1e-7,
max_lr=1e-3,
lr_pump_scale=1.1,
lr_dump_scale=0.85,
lr_bump=1e-6, # amount to bump the lr when adjusting
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta2=0.999,
weight_decay=0.0,
do_paramiter_swapping=False,
paramiter_swapping_factor=0.1,
):
self.lr = lr
if self.lr > 1e-3:
print(f"Warning! Start lr is very high: {self.lr}. Forcing it to 1e-6. This does not work like Prodigy.")
self.lr = 1e-6
self.min_lr = min_lr
self.max_lr = max_lr
self.lr_pump_scale = lr_pump_scale
self.lr_dump_scale = lr_dump_scale
self.lr_bump = lr_bump
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta2": beta2,
"weight_decay": weight_decay,
}
super().__init__(params, defaults)
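
Given the reworked signature, constructing the optimizer looks roughly like this (toy model; the values are just the defaults shown above, and the import path is assumed from the repo layout):

import torch
from toolkit.optimizers.automagic import Automagic  # path assumed

model = torch.nn.Linear(32, 32)
optimizer = Automagic(
    model.parameters(),
    lr=1e-6,       # starting per-parameter lr (not a Prodigy-style multiplier)
    min_lr=1e-7,   # lower clamp for the lr mask
    max_lr=1e-3,   # upper clamp for the lr mask
    lr_bump=1e-6,  # additive adjustment applied per parameter each step
    beta2=0.999,
    weight_decay=0.0,
)
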
@@ -119,8 +118,6 @@ class Automagic(torch.optim.Optimizer):
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copy from fairseq's adafactor implementation:
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
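
_approx_sq_grad is the Adafactor-style rank-1 reconstruction of the second moment from its row and column averages; the optimizer multiplies the result by the gradient to get the update. A self-contained sketch of what it computes (local names, illustrative shapes):

import torch

def approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    # 1/sqrt(V_hat) where V_hat[i, j] ~= row[i] * col[j] / mean(row)
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return r_factor * c_factor

grad = torch.randn(64, 128)
row = (grad ** 2).mean(dim=-1)              # shape (64,)
col = (grad ** 2).mean(dim=-2)              # shape (128,)
update = approx_sq_grad(row, col) * grad    # elementwise grad / sqrt(V_hat)
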
@@ -136,7 +133,7 @@ class Automagic(torch.optim.Optimizer):
param.grad = param._accum_grad
del param._accum_grad
# adafactor manages its own lr
# automagic manages its own lr
def get_learning_rates(self):
lrs = [
@@ -185,13 +182,20 @@ class Automagic(torch.optim.Optimizer):
if len(state) == 0:
self.initialize_state(p)
else:
# Check if exp_avg_sq_row and exp_avg_sq_col exist for factored case
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
if "exp_avg_sq_row" not in state or "exp_avg_sq_col" not in state:
state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad)
else:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
# Check if exp_avg_sq exists for non-factored case
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
if "exp_avg_sq" not in state:
state["exp_avg_sq"] = torch.zeros_like(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
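
As in Adafactor, the factored path appears to apply to parameters with at least two dimensions; the row and column accumulators drop the last and second-to-last dimension respectively, while 1-D parameters keep a full second-moment buffer. A quick shape check:

import torch

weight = torch.randn(320, 768)                                        # 2-D parameter: factored
exp_avg_sq_row = torch.zeros(weight.shape[:-1])                       # shape (320,)
exp_avg_sq_col = torch.zeros(weight.shape[:-2] + weight.shape[-1:])   # shape (768,)

bias = torch.randn(320)                       # 1-D parameter: not factored
exp_avg_sq = torch.zeros_like(bias)           # full (320,) buffer
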
@@ -200,11 +204,14 @@ class Automagic(torch.optim.Optimizer):
if p.dtype != torch.float32:
p_data_fp32 = p_data_fp32.clone().float()
# Initialize step if it doesn't exist
if "step" not in state:
state["step"] = 0
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
# lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
# Use fixed beta2 from group instead of decay_rate calculation
beta2 = group["beta2"]
eps = group["eps"]
if isinstance(eps, tuple) or isinstance(eps, list):
eps = eps[0]
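
The dropped line was Adafactor's step-dependent decay, beta2t = 1 - step**decay_rate with decay_rate = -0.8, which starts at 0 and creeps toward 1; the rework pins the second-moment decay to a constant beta2 = 0.999. A quick numeric comparison:

import math

decay_rate = -0.8
for step in (1, 10, 100, 1000, 10000):
    beta2t = 1.0 - math.pow(step, decay_rate)
    # prints roughly 0.0, 0.84, 0.97, 0.996, 0.9994 vs the fixed 0.999
    print(step, round(beta2t, 4), "vs fixed beta2 = 0.999")
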
@@ -213,10 +220,10 @@ class Automagic(torch.optim.Optimizer):
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(
update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(
update.mean(dim=-2), alpha=(1.0 - beta2t))
exp_avg_sq_row.mul_(beta2).add_(
update.mean(dim=-1), alpha=(1.0 - beta2))
exp_avg_sq_col.mul_(beta2).add_(
update.mean(dim=-2), alpha=(1.0 - beta2))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(
@@ -225,20 +232,16 @@ class Automagic(torch.optim.Optimizer):
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
exp_avg_sq.mul_(beta2).add_(update, alpha=(1.0 - beta2))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_(
(self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
# calculate new lr mask. If the updated param is going in the same direction, increase lr, else decrease it
# update the lr mask. self.lr_momentum is < 1.0. If a parameter is positive and increasing (or negative and decreasing), increase lr
# for that single parameter. If a parameter is negative and increasing or positive and decreasing, decrease lr for that single parameter.
# to decrease lr, multiply by self.lr_momentum; to increase lr, divide by self.lr_momentum.
# not doing it this way anymore
# update.mul_(lr)
# Ensure state is properly initialized
if 'last_polarity' not in state or 'lr_mask' not in state:
self.initialize_state(p)
# Get signs of current last update and updates
last_polarity = state['last_polarity']
current_polarity = (update > 0).to(torch.bool)
@@ -251,8 +254,8 @@ class Automagic(torch.optim.Optimizer):
# Update learning rate mask based on sign agreement
new_lr = torch.where(
sign_agreement > 0,
lr_mask * self.lr_pump_scale, # Increase lr
lr_mask * self.lr_dump_scale # Decrease lr
lr_mask + self.lr_bump, # Increase lr
lr_mask - self.lr_bump # Decrease lr
)
# Clip learning rates to bounds
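
This is the heart of the rework: the per-parameter lr mask used to grow and shrink multiplicatively (x1.1 / x0.85) and now moves by a fixed additive lr_bump whenever a parameter's update keeps or flips its sign. A self-contained sketch of the rule with illustrative values (the clamp mirrors the bounds step that follows):

import torch

lr_bump, min_lr, max_lr = 1e-6, 1e-7, 1e-3

lr_mask = torch.full((4,), 1e-6)                  # current per-parameter lrs
last_polarity = torch.tensor([True, False, True, False])
update = torch.tensor([0.3, -0.2, -0.1, 0.4])     # this step's update

current_polarity = update > 0
sign_agreement = (last_polarity == current_polarity).float() * 2 - 1  # +1 agree, -1 disagree

# Same direction as last step -> nudge the lr up; flipped direction -> nudge it down.
new_lr = torch.where(sign_agreement > 0, lr_mask + lr_bump, lr_mask - lr_bump)
new_lr = new_lr.clamp(min=min_lr, max=max_lr)     # keep every lr inside [min_lr, max_lr]
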
@@ -269,8 +272,11 @@ class Automagic(torch.optim.Optimizer):
state['avg_lr'] = torch.mean(new_lr)
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=(-group["weight_decay"] * new_lr))
# Apply weight decay with per-parameter learning rates
# Instead of using add_ with a tensor alpha (which isn't supported),
# we'll use element-wise multiplication to apply the weight decay
weight_decay_update = p_data_fp32 * (-group["weight_decay"]) * new_lr
p_data_fp32.add_(weight_decay_update)
p_data_fp32.add_(-update)
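
Because Tensor.add_(other, alpha=...) only accepts a scalar alpha, the decoupled weight decay has to be applied elementwise once the learning rate is a per-parameter tensor. A minimal illustration:

import torch

weight_decay = 0.01
p = torch.randn(4)
new_lr = torch.tensor([1e-6, 2e-6, 5e-7, 1e-6])   # per-parameter lrs from the mask

# Equivalent to p -= weight_decay * new_lr * p, done with an explicit elementwise
# product because a tensor-valued alpha is not supported by add_.
p.add_(p * (-weight_decay) * new_lr)
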
@@ -313,7 +319,11 @@ class Automagic(torch.optim.Optimizer):
new_sace_state = {}
for p, state in orig_state_dict['state'].items():
save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
save_state['lr_mask'] = state['lr_mask'].state_dict()
# Check if lr_mask exists in the state before trying to access it
if 'lr_mask' in state:
save_state['lr_mask'] = state['lr_mask'].state_dict()
new_sace_state[p] = save_state
orig_state_dict['state'] = new_sace_state
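
The lr mask is an Auto8bitTensor, and state_dict() stores that tensor's own sub-state instead of the raw mask. Judging from the load path below, the saved form carries at least 'quantized', 'scale', and 'orig_dtype'; a hedged sketch of dequantizing it (the exact Auto8bitTensor layout beyond those keys is an assumption):

import torch

# Assumed saved layout for one parameter's lr mask, mirroring the load code below.
saved_lr_mask = {
    "quantized": torch.randint(0, 256, (4,), dtype=torch.uint8),
    "scale": torch.tensor(1e-8),
    "orig_dtype": torch.float32,
}

# Same recovery the optimizer uses: cast back to the original dtype, then rescale.
lr_mask = saved_lr_mask["quantized"].to(saved_lr_mask["orig_dtype"]) * saved_lr_mask["scale"]
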
@@ -321,17 +331,93 @@ class Automagic(torch.optim.Optimizer):
return orig_state_dict
def load_state_dict(self, state_dict, strict=True):
# load the lr_mask from the state_dict
# don't load state dict for now. Has a bug. Need to fix it.
return
idx = 0
# Validate that the state_dict is from an Automagic optimizer
is_valid_automagic_state = False
# Check if state_dict has the expected structure
if 'state' in state_dict and isinstance(state_dict['state'], dict):
# Check if at least one state entry has an lr_mask, which is specific to Automagic
for param_id, param_state in state_dict['state'].items():
if isinstance(param_state, dict) and 'lr_mask' in param_state:
is_valid_automagic_state = True
break
if not is_valid_automagic_state:
return
# First, call the parent class's load_state_dict to load the basic optimizer state
# We'll handle the lr_mask separately
state_dict_copy = {
'state': {},
'param_groups': state_dict['param_groups']
}
# Copy all state entries except lr_mask
for param_id, param_state in state_dict['state'].items():
state_dict_copy['state'][param_id] = {
k: v for k, v in param_state.items() if k != 'lr_mask'
}
# Call parent class load_state_dict with the modified state dict
super().load_state_dict(state_dict_copy)
# Now handle the lr_mask separately
# We need to map the saved parameters to the current parameters
# This is tricky because the parameter IDs might be different
# Get all current parameters that require gradients
current_params = []
for group in self.param_groups:
for p in group['params']:
self.initialize_state(p)
state = self.state[p]
m = state_dict['state'][idx]['lr_mask']
sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
state['lr_mask'] = Auto8bitTensor(sd_mask)
del state_dict['state'][idx]['lr_mask']
idx += 1
super().load_state_dict(state_dict)
if p.requires_grad:
current_params.append(p)
# If the number of parameters doesn't match, we can't reliably map them
if len(current_params) != len(state_dict['param_groups'][0]['params']):
print(f"WARNING: Number of parameters doesn't match between saved state ({len(state_dict['param_groups'][0]['params'])}) "
f"and current model ({len(current_params)}). Learning rate masks may not be correctly loaded.")
# Map parameters by their position in the param_groups
# This assumes the order of parameters is preserved between saving and loading
saved_param_ids = list(state_dict['state'].keys())
for i, current_param in enumerate(current_params):
if i >= len(saved_param_ids):
break
saved_param_id = saved_param_ids[i]
saved_state = state_dict['state'][saved_param_id]
# Skip if this saved state doesn't have an lr_mask
if 'lr_mask' not in saved_state:
continue
# Initialize the state for this parameter if it doesn't exist
if current_param not in self.state:
self.initialize_state(current_param)
# Get the current state for this parameter
current_state = self.state[current_param]
# Load the lr_mask from the saved state
saved_lr_mask = saved_state['lr_mask']
# Reconstruct the Auto8bitTensor from its state dict
try:
# Make sure the shapes match
if 'quantized' in saved_lr_mask and saved_lr_mask['quantized'].shape == current_param.shape:
current_state['lr_mask'] = Auto8bitTensor(saved_lr_mask)
else:
print(f"WARNING: Shape mismatch for parameter {i}. "
f"Expected {current_param.shape}, got {saved_lr_mask['quantized'].shape if 'quantized' in saved_lr_mask else 'unknown'}. "
f"Initializing new lr_mask.")
# Initialize a new lr_mask
current_state['lr_mask'] = Auto8bitTensor(torch.ones(
current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
)
except Exception as e:
print(f"ERROR: Failed to load lr_mask for parameter {i}: {e}")
# Initialize a new lr_mask
current_state['lr_mask'] = Auto8bitTensor(torch.ones(
current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
)
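
Putting the save and load paths together: saved lr masks are matched back to the current model by parameter position rather than by the ids in the saved state, so the optimizer must be rebuilt over the same parameters in the same order before loading; mismatched shapes fall back to fresh masks. A hedged end-to-end usage sketch (toy model, illustrative path, import path assumed):

import torch
from toolkit.optimizers.automagic import Automagic  # path assumed

model = torch.nn.Linear(16, 16)
optimizer = Automagic(model.parameters(), lr=1e-6)

# ... train for a while, then checkpoint the optimizer ...
torch.save(optimizer.state_dict(), "optimizer.pt")

# On resume: rebuild over the same parameters (same order), then load.
resumed = Automagic(model.parameters(), lr=1e-6)
resumed.load_state_dict(torch.load("optimizer.pt"))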

View File

@@ -142,15 +142,15 @@ class StableDiffusion:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = str(self.accelerator.device)
self.device = device
self.device_torch = torch.device(device)
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = self.accelerator.device
self.vae_device_torch = self.accelerator.device
self.vae_device_torch = torch.device(device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
self.te_device_torch = self.accelerator.device
self.te_device_torch = torch.device(device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
self.model_config = model_config