mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Added adafactor implementation that handles stochastic rounding of update and accumulation
This commit is contained in:
@@ -1750,7 +1750,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
with torch.no_grad():
|
||||
# torch.cuda.empty_cache()
|
||||
if self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
# if the optimizer exposes a get_learning_rates method, use it
|
||||
if hasattr(optimizer, 'get_learning_rates'):
|
||||
learning_rate = optimizer.get_learning_rates()[0]
|
||||
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
|
||||
self.train_config.optimizer.lower().startswith('prodigy'):
|
||||
learning_rate = (
|
||||
optimizer.param_groups[0]["d"] *
|
||||
|
||||
Reference in New Issue
Block a user