Mirror of https://github.com/ostris/ai-toolkit.git

Fixed saving and displaying for automagic
@@ -566,7 +566,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         try:
             filename = f'optimizer.pt'
             file_path = os.path.join(self.save_root, filename)
-            torch.save(self.optimizer.state_dict(), file_path)
+            state_dict = self.optimizer.state_dict()
+            torch.save(state_dict, file_path)
         except Exception as e:
             print(e)
             print("Could not save optimizer")
@@ -1786,7 +1787,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
-            if hasattr(optimizer, 'get_learning_rates'):
+            if hasattr(optimizer, 'get_avg_learning_rate'):
+                learning_rate = optimizer.get_avg_learning_rate()
+            elif hasattr(optimizer, 'get_learning_rates'):
                 learning_rate = optimizer.get_learning_rates()[0]
             elif self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
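The log display now prefers a single averaged learning rate when the optimizer exposes one. The Automagic optimizer in this commit keeps a per-parameter 'avg_lr' in its state (see the hunks below), so get_avg_learning_rate would plausibly just average those entries. A minimal sketch under that assumption; the body here is illustrative, not the repository's implementation:

    # Hypothetical helper on Automagic: average the per-parameter 'avg_lr'
    # values kept in optimizer state, falling back to the base lr.
    def get_avg_learning_rate(self):
        lrs = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'avg_lr' in state:
                    lrs.append(state['avg_lr'])
        if len(lrs) == 0:
            return self.lr
        return sum(lrs) / len(lrs)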
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 import math
 from typing import List
 import torch
@@ -95,15 +96,16 @@ class Automagic(torch.optim.Optimizer):
 
     @staticmethod
     def _get_lr(param_group, param_state):
-        lr = param_group["avg_lr"]
-        param_scale = 1.0
-        return param_scale * lr
+        if 'avg_lr' in param_state:
+            lr = param_state["avg_lr"]
+        else:
+            lr = param_state["lr"]
+        return lr
 
     def _get_group_lr(self, group):
         group_lrs = []
         for p in group["params"]:
-            if p.grad is not None:
-                group_lrs.append(self._get_lr(group, self.state[p]))
+            group_lrs.append(self._get_lr(group, self.state[p]))
         # return avg
         if len(group_lrs) == 0:
             return self.lr
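As a small illustration of the new fallback (values here are hypothetical): a param state that has been through initialize_state carries 'avg_lr' and is preferred, otherwise the per-parameter 'lr' entry is used.

    # Hypothetical param states showing which branch _get_lr takes.
    fresh_state = {'lr': 1e-4}                    # -> returns 1e-4 via state['lr']
    initialized = {'lr': 1e-4, 'avg_lr': 2.5e-4}  # -> returns 2.5e-4 via state['avg_lr']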
@@ -179,26 +181,7 @@ class Automagic(torch.optim.Optimizer):
         factored = len(grad_shape) >= 2
         # State Initialization
         if len(state) == 0:
-            state["step"] = 0
-
-            # store the lr mask
-            state['lr_mask'] = Auto8bitTensor(torch.ones(
-                p.shape).to(p.device, dtype=torch.float32) * self.lr
-            )
-            state['avg_lr'] = torch.mean(
-                state['lr_mask'].to(torch.float32))
-            state['last_polarity'] = torch.zeros(
-                p.shape, dtype=torch.bool, device=p.device)
-
-            if factored:
-                state["exp_avg_sq_row"] = torch.zeros(
-                    grad_shape[:-1]).to(grad)
-                state["exp_avg_sq_col"] = torch.zeros(
-                    grad_shape[:-2] + grad_shape[-1:]).to(grad)
-            else:
-                state["exp_avg_sq"] = torch.zeros_like(grad)
-
-            state["RMS"] = 0
+            self.initialize_state(p)
         else:
             if factored:
                 state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
@@ -294,3 +277,57 @@ class Automagic(torch.optim.Optimizer):
             copy_stochastic(p, p_data_fp32)
 
         return loss
+
+    def initialize_state(self, p):
+        state = self.state[p]
+        state["step"] = 0
+
+        # store the lr mask
+        if 'lr_mask' not in state:
+            state['lr_mask'] = Auto8bitTensor(torch.ones(
+                p.shape).to(p.device, dtype=torch.float32) * self.lr
+            )
+        state['avg_lr'] = torch.mean(
+            state['lr_mask'].to(torch.float32))
+        if 'last_polarity' not in state:
+            state['last_polarity'] = torch.zeros(
+                p.shape, dtype=torch.bool, device=p.device)
+
+        factored = len(p.shape) >= 2
+        if factored:
+            state["exp_avg_sq_row"] = torch.zeros(
+                p.shape[:-1]).to(p)
+            state["exp_avg_sq_col"] = torch.zeros(
+                p.shape[:-2] + p.shape[-1:]).to(p)
+        else:
+            state["exp_avg_sq"] = torch.zeros_like(p)
+
+        state["RMS"] = 0
+
+    # override the state_dict to save the lr_mask
+    def state_dict(self, *args, **kwargs):
+        orig_state_dict = super().state_dict(*args, **kwargs)
+        # convert the state to quantized tensor to scale and quantized
+        new_sace_state = {}
+        for p, state in orig_state_dict['state'].items():
+            save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
+            save_state['lr_mask'] = state['lr_mask'].state_dict()
+            new_sace_state[p] = save_state
+
+        orig_state_dict['state'] = new_sace_state
+
+        return orig_state_dict
+
+    def load_state_dict(self, state_dict, strict=True):
+        # load the lr_mask from the state_dict
+        idx = 0
+        for group in self.param_groups:
+            for p in group['params']:
+                self.initialize_state(p)
+                state = self.state[p]
+                m = state_dict['state'][idx]['lr_mask']
+                sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
+                state['lr_mask'] = Auto8bitTensor(sd_mask)
+                del state_dict['state'][idx]['lr_mask']
+                idx += 1
+        super().load_state_dict(state_dict, strict)
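Taken together with the BaseSDTrainProcess change above, the save/restore round trip looks roughly like this; optimizer and save_root are placeholders for whatever the training process uses:

    import os
    import torch

    # Saving: state_dict() swaps each lr_mask for its quantized representation.
    state_dict = optimizer.state_dict()
    torch.save(state_dict, os.path.join(save_root, 'optimizer.pt'))

    # Restoring: load_state_dict() re-initializes per-parameter state, then
    # rebuilds each lr_mask from the saved 'quantized', 'scale' and 'orig_dtype'.
    loaded = torch.load(os.path.join(save_root, 'optimizer.pt'))
    optimizer.load_state_dict(loaded)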
@@ -241,7 +241,7 @@ class Auto8bitTensor:
         self.orig_dtype = state_dict['orig_dtype']
 
     def __str__(self):
-        return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"
+        return f"Auto8bitTensor({self.dequantize()})"
 
 
 def stochastic_grad_accummulation(param):
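With this change, printing an Auto8bitTensor shows its dequantized values instead of only the scale and original dtype. A small sketch, assuming the construct-from-tensor usage shown in initialize_state; the mask value is a placeholder:

    # Hypothetical lr mask; __str__ now prints the dequantized tensor.
    mask = Auto8bitTensor(torch.ones(4, 4, dtype=torch.float32) * 1e-4)
    print(mask)  # Auto8bitTensor(tensor([[...]])) rather than scale/orig_dtype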