Fixed state saving and learning-rate display for the Automagic optimizer

Jaret Burkett
2024-11-29 08:00:22 -07:00
parent cbe31eaf0a
commit f213996aa5
3 changed files with 68 additions and 28 deletions

View File

@@ -566,7 +566,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         try:
             filename = f'optimizer.pt'
             file_path = os.path.join(self.save_root, filename)
-            torch.save(self.optimizer.state_dict(), file_path)
+            state_dict = self.optimizer.state_dict()
+            torch.save(state_dict, file_path)
         except Exception as e:
             print(e)
             print("Could not save optimizer")
@@ -1786,7 +1787,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
-            if hasattr(optimizer, 'get_learning_rates'):
+            if hasattr(optimizer, 'get_avg_learning_rate'):
+                learning_rate = optimizer.get_avg_learning_rate()
+            elif hasattr(optimizer, 'get_learning_rates'):
                 learning_rate = optimizer.get_learning_rates()[0]
             elif self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
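The display logic now prefers get_avg_learning_rate() (which Automagic exposes below) before falling back to get_learning_rates() and the D-Adaptation/Prodigy convention. A standalone sketch of that fallback chain, with optimizer_name standing in for self.train_config.optimizer and the d * lr product assumed for the adaptive optimizers:

    def resolve_display_lr(optimizer, optimizer_name: str, default_lr: float) -> float:
        # illustrative helper mirroring the branch above; not part of the commit
        if hasattr(optimizer, 'get_avg_learning_rate'):
            return float(optimizer.get_avg_learning_rate())
        if hasattr(optimizer, 'get_learning_rates'):
            return float(optimizer.get_learning_rates()[0])
        if optimizer_name.lower().startswith(('dadaptation', 'prodigy')):
            group = optimizer.param_groups[0]
            return group['d'] * group['lr']  # effective lr commonly reported as d * lr
        return default_lr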

View File

@@ -1,3 +1,4 @@
+from collections import OrderedDict
 import math
 from typing import List
 import torch
@@ -95,15 +96,16 @@ class Automagic(torch.optim.Optimizer):
     @staticmethod
     def _get_lr(param_group, param_state):
-        lr = param_group["avg_lr"]
-        param_scale = 1.0
-        return param_scale * lr
+        if 'avg_lr' in param_state:
+            lr = param_state["avg_lr"]
+        else:
+            lr = param_state["lr"]
+        return lr

     def _get_group_lr(self, group):
         group_lrs = []
         for p in group["params"]:
-            if p.grad is not None:
-                group_lrs.append(self._get_lr(group, self.state[p]))
+            group_lrs.append(self._get_lr(group, self.state[p]))
         # return avg
         if len(group_lrs) == 0:
             return self.lr
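_get_lr now reads the per-parameter state (the avg_lr derived from the lr_mask) rather than the param group, and _get_group_lr averages it over a group's params. The get_avg_learning_rate() method used by the training loop above is not shown in this hunk; a minimal sketch of how such a method could be built from the same per-parameter avg_lr entries:

    import torch

    def get_avg_learning_rate(self):
        # hypothetical sketch: mean of the per-parameter 'avg_lr' values kept in self.state
        lrs = [s['avg_lr'] for s in self.state.values() if isinstance(s, dict) and 'avg_lr' in s]
        if len(lrs) == 0:
            return self.lr  # no state yet; fall back to the base learning rate
        return torch.stack([torch.as_tensor(lr, dtype=torch.float32) for lr in lrs]).mean().item()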
@@ -179,26 +181,7 @@
                 factored = len(grad_shape) >= 2

                 # State Initialization
                 if len(state) == 0:
-                    state["step"] = 0
-                    # store the lr mask
-                    state['lr_mask'] = Auto8bitTensor(torch.ones(
-                        p.shape).to(p.device, dtype=torch.float32) * self.lr
-                    )
-                    state['avg_lr'] = torch.mean(
-                        state['lr_mask'].to(torch.float32))
-                    state['last_polarity'] = torch.zeros(
-                        p.shape, dtype=torch.bool, device=p.device)
-                    if factored:
-                        state["exp_avg_sq_row"] = torch.zeros(
-                            grad_shape[:-1]).to(grad)
-                        state["exp_avg_sq_col"] = torch.zeros(
-                            grad_shape[:-2] + grad_shape[-1:]).to(grad)
-                    else:
-                        state["exp_avg_sq"] = torch.zeros_like(grad)
-                    state["RMS"] = 0
+                    self.initialize_state(p)
                 else:
                     if factored:
                         state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
@@ -294,3 +277,57 @@ class Automagic(torch.optim.Optimizer):
                 copy_stochastic(p, p_data_fp32)

         return loss

+    def initialize_state(self, p):
+        state = self.state[p]
+        state["step"] = 0
+
+        # store the per-parameter learning-rate mask
+        if 'lr_mask' not in state:
+            state['lr_mask'] = Auto8bitTensor(torch.ones(
+                p.shape).to(p.device, dtype=torch.float32) * self.lr
+            )
+        state['avg_lr'] = torch.mean(
+            state['lr_mask'].to(torch.float32))
+        if 'last_polarity' not in state:
+            state['last_polarity'] = torch.zeros(
+                p.shape, dtype=torch.bool, device=p.device)
+
+        factored = len(p.shape) >= 2
+        if factored:
+            state["exp_avg_sq_row"] = torch.zeros(
+                p.shape[:-1]).to(p)
+            state["exp_avg_sq_col"] = torch.zeros(
+                p.shape[:-2] + p.shape[-1:]).to(p)
+        else:
+            state["exp_avg_sq"] = torch.zeros_like(p)
+        state["RMS"] = 0
+
+    # override state_dict so the lr_mask is saved in its quantized form
+    def state_dict(self, *args, **kwargs):
+        orig_state_dict = super().state_dict(*args, **kwargs)
+
+        # replace each lr_mask with its {quantized, scale, orig_dtype} payload
+        new_save_state = {}
+        for p, state in orig_state_dict['state'].items():
+            save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
+            save_state['lr_mask'] = state['lr_mask'].state_dict()
+            new_save_state[p] = save_state
+
+        orig_state_dict['state'] = new_save_state
+        return orig_state_dict
+
+    def load_state_dict(self, state_dict, strict=True):
+        # rebuild each param's state, then restore its lr_mask from the quantized payload
+        idx = 0
+        for group in self.param_groups:
+            for p in group['params']:
+                self.initialize_state(p)
+                state = self.state[p]
+                m = state_dict['state'][idx]['lr_mask']
+                sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
+                state['lr_mask'] = Auto8bitTensor(sd_mask)
+                del state_dict['state'][idx]['lr_mask']
+                idx += 1
+
+        super().load_state_dict(state_dict, strict)
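Taken together, the two overrides let the 8-bit lr_mask survive a plain torch.save / torch.load round trip: state_dict() swaps each mask for its {quantized, scale, orig_dtype} payload, and load_state_dict() rebuilds an Auto8bitTensor from it. A usage sketch under those assumptions (constructor arguments are illustrative, and at least one step is taken so per-parameter state exists before saving):

    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = Automagic(model.parameters(), lr=1e-6)  # illustrative constructor args

    # one forward/backward/step so per-parameter state (incl. lr_mask) is populated
    model(torch.randn(2, 8)).sum().backward()
    optimizer.step()

    torch.save(optimizer.state_dict(), 'optimizer.pt')  # lr_mask stored as quantized payload

    restored = Automagic(model.parameters(), lr=1e-6)
    restored.load_state_dict(torch.load('optimizer.pt', map_location='cpu'))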

View File

@@ -241,7 +241,7 @@ class Auto8bitTensor:
         self.orig_dtype = state_dict['orig_dtype']

     def __str__(self):
-        return f"Auto8bitTensor(scale={self.scale}, orig_dtype={self.orig_dtype})"
+        return f"Auto8bitTensor({self.dequantize()})"


def stochastic_grad_accummulation(param):
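For reference, the fields the loader above reads from Auto8bitTensor.state_dict() ('quantized', 'scale', 'orig_dtype') correspond to a simple absmax 8-bit scheme. A minimal self-contained sketch of that idea; the real Auto8bitTensor is only partially visible in this diff:

    import torch

    class TinyAuto8bitTensor:
        """Hypothetical stand-in illustrating absmax int8 quantization with a single scale."""

        def __init__(self, x: torch.Tensor):
            self.orig_dtype = x.dtype
            self.scale = x.abs().max().clamp(min=1e-8) / 127.0
            self.quantized = torch.round(x / self.scale).clamp(-127, 127).to(torch.int8)

        def dequantize(self) -> torch.Tensor:
            # inverse of the load path above: quantized.to(orig_dtype) * scale
            return self.quantized.to(self.orig_dtype) * self.scale

        def state_dict(self):
            return {'quantized': self.quantized, 'scale': self.scale, 'orig_dtype': self.orig_dtype}

        def __str__(self):
            return f"TinyAuto8bitTensor({self.dequantize()})"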