Added Adafactor implementation that handles stochastic rounding of the update and accumulation

Jaret Burkett
2024-10-30 05:25:57 -06:00
parent e72b59a8e9
commit 58f9d01c2b
6 changed files with 458 additions and 269 deletions
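
The commit title refers to stochastic rounding when the optimizer writes its update and its accumulator state back to low-precision buffers. As a rough sketch of that technique (the helper name and the bit-level details below are illustrative assumptions, not necessarily the code this commit adds), a float32 result can be rounded stochastically into a bfloat16 buffer by adding uniform noise to the bits that plain truncation would discard:

    import torch

    def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
        # Illustrative sketch: write float32 `source` into bfloat16 `target`
        # with stochastic rounding instead of deterministic rounding.
        assert target.dtype == torch.bfloat16 and source.dtype == torch.float32

        # Reinterpret the float32 bits as int32 so they can be manipulated directly.
        bits = source.clone().view(torch.int32)

        # bfloat16 keeps only the top 16 bits of a float32. Adding uniform noise
        # in [0, 2^16) before truncating makes each value round up with
        # probability proportional to the discarded low bits.
        bits.add_(torch.randint_like(bits, 0, 1 << 16))

        # Zero the low 16 bits (-65536 == 0xFFFF0000 as a signed int32 mask),
        # reinterpret as float32, and store into the bfloat16 target.
        bits.bitwise_and_(-65536)
        target.copy_(bits.view(torch.float32))

Applied to an Adafactor step, the same rounding would presumably be used both when the parameter update is written and when the second-moment accumulators are updated, matching the "update and accumulation" wording in the title.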

@@ -1750,7 +1750,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
-            if self.train_config.optimizer.lower().startswith('dadaptation') or \
+            # if the optimizer exposes get_learning_rates, use it
+            if hasattr(optimizer, 'get_learning_rates'):
+                learning_rate = optimizer.get_learning_rates()[0]
+            elif self.train_config.optimizer.lower().startswith('dadaptation') or \
                     self.train_config.optimizer.lower().startswith('prodigy'):
                 learning_rate = (
                     optimizer.param_groups[0]["d"] *
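
The hunk above only shows the call site for the new duck-typed check: any optimizer that exposes a get_learning_rates() method gets its first entry logged, and everything else falls through to the existing dadaptation/prodigy handling. As a minimal sketch of what such an optimizer could look like (the class below is a hypothetical wrapper for illustration, not the Adafactor class this commit adds):

    import torch

    class LRReportingAdamW(torch.optim.AdamW):
        # Hypothetical optimizer satisfying the
        # hasattr(optimizer, 'get_learning_rates') check in the training loop above.
        def get_learning_rates(self):
            # One learning rate per param group; the loop logs the first one.
            return [group["lr"] for group in self.param_groups]

Because the new branch is an elif placed in front of the old checks, configs using dadaptation or prodigy keep their existing learning-rate reporting.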