mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fixed saving and displaying for automagic
This commit is contained in:
@@ -566,7 +566,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
try:
|
||||
filename = f'optimizer.pt'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
torch.save(self.optimizer.state_dict(), file_path)
|
||||
state_dict = self.optimizer.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Could not save optimizer")
|
||||
@@ -1786,7 +1787,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
with torch.no_grad():
|
||||
# torch.cuda.empty_cache()
|
||||
# if optimizer has get_lrs method, then use it
|
||||
if hasattr(optimizer, 'get_learning_rates'):
|
||||
if hasattr(optimizer, 'get_avg_learning_rate'):
|
||||
learning_rate = optimizer.get_avg_learning_rate()
|
||||
elif hasattr(optimizer, 'get_learning_rates'):
|
||||
learning_rate = optimizer.get_learning_rates()[0]
|
||||
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
self.train_config.optimizer.lower().startswith('prodigy'):
|
||||
|
||||
Reference in New Issue
Block a user