Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Fixed saving and displaying for automagic
@@ -566,7 +566,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         try:
             filename = f'optimizer.pt'
             file_path = os.path.join(self.save_root, filename)
-            torch.save(self.optimizer.state_dict(), file_path)
+            state_dict = self.optimizer.state_dict()
+            torch.save(state_dict, file_path)
         except Exception as e:
             print(e)
             print("Could not save optimizer")

@@ -1786,7 +1787,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
-            if hasattr(optimizer, 'get_learning_rates'):
+            if hasattr(optimizer, 'get_avg_learning_rate'):
+                learning_rate = optimizer.get_avg_learning_rate()
+            elif hasattr(optimizer, 'get_learning_rates'):
                 learning_rate = optimizer.get_learning_rates()[0]
             elif self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):

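Automagic learns a per-parameter learning-rate mask rather than a single scalar, so the display code above now asks the optimizer for an aggregate value first. A standalone sketch of that probing order follows; the attribute names are the ones the diff checks, while the helper itself and its fallback argument are illustrative, and the dadaptation/prodigy branch that the hunk cuts off is left out:

def displayable_lr(optimizer, fallback: float) -> float:
    # Prefer a single averaged per-parameter LR, then the first group LR,
    # then whatever default the caller passes in.
    if hasattr(optimizer, 'get_avg_learning_rate'):
        return float(optimizer.get_avg_learning_rate())
    if hasattr(optimizer, 'get_learning_rates'):
        return float(optimizer.get_learning_rates()[0])
    return fallback
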
@@ -1,3 +1,4 @@
 from collections import OrderedDict
 import math
 from typing import List
 import torch

@@ -95,15 +96,16 @@ class Automagic(torch.optim.Optimizer):
 
     @staticmethod
     def _get_lr(param_group, param_state):
-        lr = param_group["avg_lr"]
-        param_scale = 1.0
-        return param_scale * lr
+        if 'avg_lr' in param_state:
+            lr = param_state["avg_lr"]
+        else:
+            lr = param_state["lr"]
+        return lr
 
     def _get_group_lr(self, group):
         group_lrs = []
         for p in group["params"]:
-            if p.grad is not None:
-                group_lrs.append(self._get_lr(group, self.state[p]))
+            group_lrs.append(self._get_lr(group, self.state[p]))
         # return avg
         if len(group_lrs) == 0:
             return self.lr

@@ -179,26 +181,7 @@ class Automagic(torch.optim.Optimizer):
                 factored = len(grad_shape) >= 2
                 # State Initialization
                 if len(state) == 0:
-                    state["step"] = 0
-
-                    # store the lr mask
-                    state['lr_mask'] = Auto8bitTensor(torch.ones(
-                        p.shape).to(p.device, dtype=torch.float32) * self.lr
-                    )
-                    state['avg_lr'] = torch.mean(
-                        state['lr_mask'].to(torch.float32))
-                    state['last_polarity'] = torch.zeros(
-                        p.shape, dtype=torch.bool, device=p.device)
-
-                    if factored:
-                        state["exp_avg_sq_row"] = torch.zeros(
-                            grad_shape[:-1]).to(grad)
-                        state["exp_avg_sq_col"] = torch.zeros(
-                            grad_shape[:-2] + grad_shape[-1:]).to(grad)
-                    else:
-                        state["exp_avg_sq"] = torch.zeros_like(grad)
-
-                    state["RMS"] = 0
+                    self.initialize_state(p)
                 else:
                     if factored:
                         state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(

@@ -294,3 +277,57 @@ class Automagic(torch.optim.Optimizer):
                     copy_stochastic(p, p_data_fp32)
 
         return loss
+
+    def initialize_state(self, p):
+        state = self.state[p]
+        state["step"] = 0
+
+        # store the lr mask
+        if 'lr_mask' not in state:
+            state['lr_mask'] = Auto8bitTensor(torch.ones(
+                p.shape).to(p.device, dtype=torch.float32) * self.lr
+            )
+            state['avg_lr'] = torch.mean(
+                state['lr_mask'].to(torch.float32))
+        if 'last_polarity' not in state:
+            state['last_polarity'] = torch.zeros(
+                p.shape, dtype=torch.bool, device=p.device)
+
+        factored = len(p.shape) >= 2
+        if factored:
+            state["exp_avg_sq_row"] = torch.zeros(
+                p.shape[:-1]).to(p)
+            state["exp_avg_sq_col"] = torch.zeros(
+                p.shape[:-2] + p.shape[-1:]).to(p)
+        else:
+            state["exp_avg_sq"] = torch.zeros_like(p)
+
+        state["RMS"] = 0
+
+    # override the state_dict to save the lr_mask
+    def state_dict(self, *args, **kwargs):
+        orig_state_dict = super().state_dict(*args, **kwargs)
+        # convert the state to quantized tensor to scale and quantized
+        new_sace_state = {}
+        for p, state in orig_state_dict['state'].items():
+            save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
+            save_state['lr_mask'] = state['lr_mask'].state_dict()
+            new_sace_state[p] = save_state
+
+        orig_state_dict['state'] = new_sace_state
+
+        return orig_state_dict
+
+    def load_state_dict(self, state_dict, strict=True):
+        # load the lr_mask from the state_dict
+        idx = 0
+        for group in self.param_groups:
+            for p in group['params']:
+                self.initialize_state(p)
+                state = self.state[p]
+                m = state_dict['state'][idx]['lr_mask']
+                sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
+                state['lr_mask'] = Auto8bitTensor(sd_mask)
+                del state_dict['state'][idx]['lr_mask']
+                idx += 1
+        super().load_state_dict(state_dict, strict)

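With this override in place, the optimizer.pt file written by the first hunk stores each parameter's lr_mask as the small dict produced by Auto8bitTensor.state_dict() instead of a full-precision tensor. A minimal inspection sketch, assuming a checkpoint saved by the code above; the keys 'quantized', 'scale' and 'orig_dtype' are the ones load_state_dict reads, the rest is illustrative:

import torch

sd = torch.load('optimizer.pt', map_location='cpu')
for idx, entry in sd['state'].items():
    mask = entry['lr_mask']  # dict: {'quantized', 'scale', 'orig_dtype'}
    # Same expression load_state_dict uses to rebuild the mask
    restored = mask['quantized'].to(mask['orig_dtype']) * mask['scale']
    print(idx, float(restored.mean()))  # average learned LR for this parameter
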
@@ -241,7 +241,7 @@ class Auto8bitTensor:
         self.orig_dtype = state_dict['orig_dtype']
 
     def __str__(self):
-        return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"
+        return f"Auto8bitTensor({self.dequantize()})"
 
 
 def stochastic_grad_accummulation(param):

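For context on what the lr_mask serialization relies on, here is a toy round trip using the same three keys the checkpoint stores. The absmax-to-uint8 scheme below is an assumption about Auto8bitTensor's internals, not taken from the repo; only the dequantize expression (quantized.to(orig_dtype) * scale) matches what load_state_dict does above:

import torch

def quantize(t: torch.Tensor) -> dict:
    # Assumed scheme: positive values scaled into 0..255 and stored as uint8.
    scale = t.abs().max().clamp(min=1e-12) / 255.0
    return {
        'quantized': (t / scale).round().clamp(0, 255).to(torch.uint8),
        'scale': scale,
        'orig_dtype': t.dtype,
    }

def dequantize(state: dict) -> torch.Tensor:
    # Mirrors the expression used when the optimizer checkpoint is loaded.
    return state['quantized'].to(state['orig_dtype']) * state['scale']

mask = torch.ones(4, 8) * 1e-5  # an lr mask starts out as ones * lr
print(torch.allclose(dequantize(quantize(mask)), mask, rtol=1e-2))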