Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Reworked automagic optimizer and did more testing. Starting to really like it. Working well.
@@ -62,7 +62,7 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
DecoratorConfig
from toolkit.logging_aitk import create_logger
from diffusers import FluxTransformer2DModel
from toolkit.accelerator import get_accelerator
from toolkit.accelerator import get_accelerator, unwrap_model
from toolkit.print import print_acc
from accelerate import Accelerator
import transformers

@@ -629,7 +629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
state_dict = self.optimizer.state_dict()
state_dict = unwrap_model(self.optimizer).state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
except Exception as e:

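Note: the change above saves the optimizer state through unwrap_model rather than calling state_dict() on the wrapped optimizer directly. toolkit.accelerator.unwrap_model itself is not shown in this diff, so the helper below is only a hypothetical stand-in sketching the idea, assuming accelerate's AcceleratedOptimizer keeps the underlying torch optimizer on an .optimizer attribute.

```python
import torch

def unwrap_optimizer(opt):
    # Hypothetical illustration only; not the toolkit's actual unwrap_model.
    # If accelerate has wrapped the optimizer, take state from the inner
    # torch.optim.Optimizer so the saved file is free of wrapper details.
    inner = getattr(opt, "optimizer", None)
    return inner if isinstance(inner, torch.optim.Optimizer) else opt

# e.g. torch.save(unwrap_optimizer(optimizer).state_dict(), file_path)
```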
@@ -1457,7 +1457,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(previous_refiner_save)

self.sd = ModelClass(
device=self.device,
# todo handle single gpu and multi gpu here
# device=self.device,
device=self.accelerator.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,

@@ -113,15 +113,15 @@ class BaseModel:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = str(self.accelerator.device)
self.device = device
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = self.accelerator.device
self.device_torch = torch.device(device)

self.vae_device_torch = self.accelerator.device
self.vae_device_torch = torch.device(device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)

self.te_device_torch = self.accelerator.device
self.te_device_torch = torch.device(device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)

self.model_config = model_config

@@ -182,7 +182,7 @@ class ToolkitModuleMixin:
lx = self.lora_down(x)
except RuntimeError as e:
print(f"Error in {self.__class__.__name__} lora_down")
print(e)
raise e

if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
lx = self.dropout(lx)

@@ -1,5 +1,3 @@
from collections import OrderedDict
import math
from typing import List
import torch
from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation

@@ -11,29 +9,30 @@ class Automagic(torch.optim.Optimizer):
def __init__(
self,
params,
lr=None,
lr=1e-6, # lr is start lr
min_lr=1e-7,
max_lr=1e-3,
lr_pump_scale=1.1,
lr_dump_scale=0.85,
lr_bump=1e-6, # amount to bump the lr when adjusting
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta2=0.999,
weight_decay=0.0,
do_paramiter_swapping=False,
paramiter_swapping_factor=0.1,
):
self.lr = lr
if self.lr > 1e-3:
print(f"Warning! Start lr is very high: {self.lr}. Forcing to 1e-6. this does not work like prodigy")
self.lr = 1e-6
self.min_lr = min_lr
self.max_lr = max_lr
self.lr_pump_scale = lr_pump_scale
self.lr_dump_scale = lr_dump_scale
self.lr_bump = lr_bump

defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta2": beta2,
"weight_decay": weight_decay,
}
super().__init__(params, defaults)

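Note: for context, a minimal construction sketch using the hyperparameters introduced above. The module path and the model are assumptions for illustration; only the argument names and defaults come from the diff.

```python
import torch.nn as nn
from toolkit.optimizers.automagic import Automagic  # assumed module path

model = nn.Linear(128, 128)
optimizer = Automagic(
    model.parameters(),
    lr=1e-6,       # starting lr; values above 1e-3 are forced back to 1e-6 by the warning above
    min_lr=1e-7,   # floor for each parameter's learning rate
    max_lr=1e-3,   # ceiling for each parameter's learning rate
    lr_bump=1e-6,  # additive step used when adjusting a parameter's lr
)
```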
@@ -119,8 +118,6 @@ class Automagic(torch.optim.Optimizer):

@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copy from fairseq's adafactor implementation:
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-
1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()

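Note: the factored second-moment approximation above follows the referenced transformers/fairseq Adafactor code. Below is a standalone restatement; the final combination of the two factors is assumed from that reference, since the hunk cuts off before the return.

```python
import torch

def approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    # Rank-1 reconstruction of the squared-gradient statistics: normalize the
    # row EMA by its mean, take reciprocal square roots, and combine with the
    # column EMA via an outer product.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return torch.mul(r_factor, c_factor)  # assumed final step, per the linked Adafactor source

row = torch.rand(4) + 1e-3   # per-row EMA of squared gradients
col = torch.rand(6) + 1e-3   # per-column EMA of squared gradients
scaling = approx_sq_grad(row, col)  # shape (4, 6), later multiplied into the gradient
```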
@@ -136,7 +133,7 @@ class Automagic(torch.optim.Optimizer):
param.grad = param._accum_grad
del param._accum_grad

# adafactor manages its own lr
# automagic manages its own lr
def get_learning_rates(self):

lrs = [

@@ -185,13 +182,20 @@ class Automagic(torch.optim.Optimizer):
if len(state) == 0:
self.initialize_state(p)
else:
# Check if exp_avg_sq_row and exp_avg_sq_col exist for factored case
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
grad)
if "exp_avg_sq_row" not in state or "exp_avg_sq_col" not in state:
state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad)
else:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
# Check if exp_avg_sq exists for non-factored case
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
if "exp_avg_sq" not in state:
state["exp_avg_sq"] = torch.zeros_like(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

p_data_fp32 = p

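Note: the factored branch above keeps one accumulator per row and one per column instead of a full-sized second moment. A small sketch of the shapes involved for a 2-D parameter:

```python
import torch

p = torch.randn(4, 6)        # a 2-D (factored) parameter
grad = torch.randn_like(p)

# Shapes created by the initialization above:
exp_avg_sq_row = torch.zeros(p.shape[:-1]).to(grad)                 # shape (4,): last dim dropped
exp_avg_sq_col = torch.zeros(p.shape[:-2] + p.shape[-1:]).to(grad)  # shape (6,): second-to-last dim dropped
```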
@@ -200,11 +204,14 @@ class Automagic(torch.optim.Optimizer):
if p.dtype != torch.float32:
p_data_fp32 = p_data_fp32.clone().float()

# Initialize step if it doesn't exist
if "step" not in state:
state["step"] = 0
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
# lr = self._get_lr(group, state)

beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
# Use fixed beta2 from group instead of decay_rate calculation
beta2 = group["beta2"]
eps = group["eps"]
if isinstance(eps, tuple) or isinstance(eps, list):
eps = eps[0]

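Note: the hunk above swaps Adafactor's step-dependent coefficient for a fixed beta2 read from the param group. A quick numeric comparison of the two:

```python
import math

step, decay_rate, beta2 = 100, -0.8, 0.999

beta2t = 1.0 - math.pow(step, decay_rate)  # old: ~0.975 at step 100, drifting toward 1.0 as step grows
fixed = beta2                              # new: constant Adam-style EMA coefficient
print(beta2t, fixed)
```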
@@ -213,10 +220,10 @@ class Automagic(torch.optim.Optimizer):
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]

exp_avg_sq_row.mul_(beta2t).add_(
update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(
update.mean(dim=-2), alpha=(1.0 - beta2t))
exp_avg_sq_row.mul_(beta2).add_(
update.mean(dim=-1), alpha=(1.0 - beta2))
exp_avg_sq_col.mul_(beta2).add_(
update.mean(dim=-2), alpha=(1.0 - beta2))

# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(

@@ -225,20 +232,16 @@ class Automagic(torch.optim.Optimizer):
else:
exp_avg_sq = state["exp_avg_sq"]

exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
exp_avg_sq.mul_(beta2).add_(update, alpha=(1.0 - beta2))
update = exp_avg_sq.rsqrt().mul_(grad)

update.div_(
(self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))

# calculate new lr mask. if the updated param is going in same direction, increase lr, else decrease
# update the lr mask. self.lr_momentum is < 1.0. If a paramiter is positive and increasing (or negative and decreasing), increase lr,
# for that single paramiter. If a paramiter is negative and increasing or positive and decreasing, decrease lr for that single paramiter.
# to decrease lr, multiple by self.lr_momentum, to increase lr, divide by self.lr_momentum.

# not doing it this way anymore
# update.mul_(lr)

# Ensure state is properly initialized
if 'last_polarity' not in state or 'lr_mask' not in state:
self.initialize_state(p)

# Get signs of current last update and updates
last_polarity = state['last_polarity']
current_polarity = (update > 0).to(torch.bool)

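Note: pulling the non-factored branch above into a standalone sketch: an EMA of the squared gradient with the fixed beta2, reciprocal-square-root preconditioning, and the RMS-based clip. The squared-gradient expression and the _rms helper are not visible in this hunk, so they are assumed to follow the usual Adafactor form.

```python
import torch

def rms(t):
    # Root-mean-square of a tensor (assumed equivalent to the optimizer's _rms).
    return t.norm(2) / (t.numel() ** 0.5)

grad = torch.randn(4, 6)
exp_avg_sq = torch.zeros_like(grad)
beta2, eps, clip_threshold = 0.999, 1e-30, 1.0

update = grad.pow(2) + eps                                    # squared gradient (assumed, Adafactor-style)
exp_avg_sq.mul_(beta2).add_(update, alpha=1.0 - beta2)        # fixed-beta2 EMA, as in the new line above
update = exp_avg_sq.rsqrt().mul_(grad)                        # precondition the raw gradient
update.div_((rms(update) / clip_threshold).clamp_(min=1.0))   # shrink the update if its RMS exceeds the threshold
```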
@@ -251,8 +254,8 @@ class Automagic(torch.optim.Optimizer):
# Update learning rate mask based on sign agreement
new_lr = torch.where(
sign_agreement > 0,
lr_mask * self.lr_pump_scale, # Increase lr
lr_mask * self.lr_dump_scale # Decrease lr
lr_mask + self.lr_bump, # Increase lr
lr_mask - self.lr_bump # Decrease lr
)

# Clip learning rates to bounds

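Note: this is the core behavioral change of the commit: the per-parameter learning-rate mask is now adjusted additively by lr_bump instead of being multiplied by pump/dump factors. A self-contained sketch of the rule follows; how sign_agreement is derived from the two polarity tensors is not shown in this hunk, so plain equality is assumed, and the mask is kept as a float tensor rather than an Auto8bitTensor.

```python
import torch

lr_bump, min_lr, max_lr = 1e-6, 1e-7, 1e-3   # defaults from __init__ above

update = torch.randn(4, 6)                   # this step's preconditioned update
lr_mask = torch.full_like(update, 1e-6)      # current per-parameter learning rates
last_polarity = torch.rand(4, 6) > 0.5       # sign of the previous step's update

current_polarity = update > 0
agree = last_polarity == current_polarity    # assumed sign-agreement test

new_lr = torch.where(
    agree,
    lr_mask + lr_bump,   # same direction as last step: raise this parameter's lr
    lr_mask - lr_bump,   # direction flipped: lower this parameter's lr
).clamp_(min=min_lr, max=max_lr)             # then clipped to bounds, as in the lines that follow
```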
@@ -269,8 +272,11 @@ class Automagic(torch.optim.Optimizer):
state['avg_lr'] = torch.mean(new_lr)

if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32, alpha=(-group["weight_decay"] * new_lr))
# Apply weight decay with per-parameter learning rates
# Instead of using add_ with a tensor alpha (which isn't supported),
# we'll use element-wise multiplication to apply the weight decay
weight_decay_update = p_data_fp32 * (-group["weight_decay"]) * new_lr
p_data_fp32.add_(weight_decay_update)

p_data_fp32.add_(-update)

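Note: the replaced weight-decay line tried to pass the per-parameter lr tensor as alpha to add_, which torch only accepts as a Python scalar; the new code applies the decay element-wise instead. A minimal sketch:

```python
import torch

p_data_fp32 = torch.randn(4, 6)              # fp32 working copy of the parameter
new_lr = torch.full_like(p_data_fp32, 1e-6)  # per-parameter learning rates
weight_decay = 0.01

# Tensor.add_(other, alpha=...) requires a scalar alpha, so fold the
# per-parameter rates in with an element-wise product instead:
weight_decay_update = p_data_fp32 * (-weight_decay) * new_lr
p_data_fp32.add_(weight_decay_update)
```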
@@ -313,7 +319,11 @@ class Automagic(torch.optim.Optimizer):
new_sace_state = {}
for p, state in orig_state_dict['state'].items():
save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
save_state['lr_mask'] = state['lr_mask'].state_dict()

# Check if lr_mask exists in the state before trying to access it
if 'lr_mask' in state:
save_state['lr_mask'] = state['lr_mask'].state_dict()

new_sace_state[p] = save_state

orig_state_dict['state'] = new_sace_state

@@ -321,17 +331,93 @@ class Automagic(torch.optim.Optimizer):
return orig_state_dict

def load_state_dict(self, state_dict, strict=True):
# load the lr_mask from the state_dict
# dont load state dict for now. Has a bug. Need to fix it.
return
idx = 0
# Validate that the state_dict is from an Automagic optimizer
is_valid_automagic_state = False

# Check if state_dict has the expected structure
if 'state' in state_dict and isinstance(state_dict['state'], dict):
# Check if at least one state entry has an lr_mask, which is specific to Automagic
for param_id, param_state in state_dict['state'].items():
if isinstance(param_state, dict) and 'lr_mask' in param_state:
is_valid_automagic_state = True
break

if not is_valid_automagic_state:
return

# First, call the parent class's load_state_dict to load the basic optimizer state
# We'll handle the lr_mask separately
state_dict_copy = {
'state': {},
'param_groups': state_dict['param_groups']
}

# Copy all state entries except lr_mask
for param_id, param_state in state_dict['state'].items():
state_dict_copy['state'][param_id] = {
k: v for k, v in param_state.items() if k != 'lr_mask'
}

# Call parent class load_state_dict with the modified state dict
super().load_state_dict(state_dict_copy)

# Now handle the lr_mask separately
# We need to map the saved parameters to the current parameters
# This is tricky because the parameter IDs might be different

# Get all current parameters that require gradients
current_params = []
for group in self.param_groups:
for p in group['params']:
self.initialize_state(p)
state = self.state[p]
m = state_dict['state'][idx]['lr_mask']
sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
state['lr_mask'] = Auto8bitTensor(sd_mask)
del state_dict['state'][idx]['lr_mask']
idx += 1
super().load_state_dict(state_dict)
if p.requires_grad:
current_params.append(p)

# If the number of parameters doesn't match, we can't reliably map them
if len(current_params) != len(state_dict['param_groups'][0]['params']):
print(f"WARNING: Number of parameters doesn't match between saved state ({len(state_dict['param_groups'][0]['params'])}) "
f"and current model ({len(current_params)}). Learning rate masks may not be correctly loaded.")

# Map parameters by their position in the param_groups
# This assumes the order of parameters is preserved between saving and loading
saved_param_ids = list(state_dict['state'].keys())

for i, current_param in enumerate(current_params):
if i >= len(saved_param_ids):
break

saved_param_id = saved_param_ids[i]
saved_state = state_dict['state'][saved_param_id]

# Skip if this saved state doesn't have an lr_mask
if 'lr_mask' not in saved_state:
continue

# Initialize the state for this parameter if it doesn't exist
if current_param not in self.state:
self.initialize_state(current_param)

# Get the current state for this parameter
current_state = self.state[current_param]

# Load the lr_mask from the saved state
saved_lr_mask = saved_state['lr_mask']

# Reconstruct the Auto8bitTensor from its state dict
try:
# Make sure the shapes match
if 'quantized' in saved_lr_mask and saved_lr_mask['quantized'].shape == current_param.shape:
current_state['lr_mask'] = Auto8bitTensor(saved_lr_mask)
else:
print(f"WARNING: Shape mismatch for parameter {i}. "
f"Expected {current_param.shape}, got {saved_lr_mask['quantized'].shape if 'quantized' in saved_lr_mask else 'unknown'}. "
f"Initializing new lr_mask.")
# Initialize a new lr_mask
current_state['lr_mask'] = Auto8bitTensor(torch.ones(
current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
)
except Exception as e:
print(f"ERROR: Failed to load lr_mask for parameter {i}: {e}")
# Initialize a new lr_mask
current_state['lr_mask'] = Auto8bitTensor(torch.ones(
current_param.shape).to(current_param.device, dtype=torch.float32) * self.lr
)

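Note: the load path above (currently short-circuited by the early return) rebuilds each lr_mask from the quantized payload saved by state_dict. Below is a sketch of that dequantization step; the Auto8bitTensor layout ('quantized', 'scale', 'orig_dtype') is inferred from one of the branches shown in this hunk and should be treated as an assumption.

```python
import torch

saved_lr_mask = {
    "quantized": torch.randint(0, 256, (4, 6), dtype=torch.uint8),  # 8-bit payload
    "scale": torch.tensor(1e-8),                                    # dequantization scale
    "orig_dtype": torch.float32,
}

# Mirrors `m['quantized'].to(m['orig_dtype']) * m['scale']` from the hunk above.
lr_mask = saved_lr_mask["quantized"].to(saved_lr_mask["orig_dtype"]) * saved_lr_mask["scale"]
```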
@@ -142,15 +142,15 @@ class StableDiffusion:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = str(self.accelerator.device)
self.device = device
self.device_torch = torch.device(device)
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = self.accelerator.device

self.vae_device_torch = self.accelerator.device
self.vae_device_torch = torch.device(device)
self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)

self.te_device_torch = self.accelerator.device
self.te_device_torch = torch.device(device)
self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)

self.model_config = model_config
