Added an Adafactor implementation that handles stochastic rounding of updates and accumulation

Jaret Burkett
2024-10-30 05:25:57 -06:00
parent e72b59a8e9
commit 58f9d01c2b
6 changed files with 458 additions and 269 deletions
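
For context, stochastic rounding writes a full-precision value into a lower-precision buffer (for example bfloat16) by rounding up or down at random, with probabilities proportional to the distance to each neighbour, so the stored value is unbiased in expectation and small optimizer updates are not silently truncated away. The sketch below is illustrative only and is not the code added in this commit; the helper name copy_stochastic_ and the bfloat16 target are assumptions.

    import torch

    def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
        # Illustrative sketch, not the toolkit's implementation.
        # Copy a float32 `source` into a bfloat16 `target` with stochastic rounding.
        # bfloat16 keeps the top 16 bits of a float32, so adding uniform random
        # noise to the 16 bits that will be dropped makes the truncation round
        # up or down with the correct probabilities.
        bits = source.detach().clone().contiguous().view(dtype=torch.int32)
        bits.add_(torch.randint_like(bits, 0, 1 << 16))  # random carry into the kept bits
        bits.bitwise_and_(-65536)                        # zero the low 16 bits (& 0xFFFF0000)
        target.copy_(bits.view(dtype=torch.float32))     # cast to bfloat16 is now exact

    # Toy check: the per-element rounding error averages out.
    x = torch.full((100_000,), 0.1234567, dtype=torch.float32)
    y = torch.empty_like(x, dtype=torch.bfloat16)
    copy_stochastic_(y, x)
    print(float(y.float().mean()))  # close to 0.1234567 in expectation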

@@ -77,9 +77,9 @@ def get_optimizer(
         except ImportError:
             raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
     elif lower_type == 'adagrad':
-        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'adafactor':
-        # hack in stochastic rounding
+        from toolkit.optimizers.adafactor import Adafactor
         if 'relative_step' not in optimizer_params:
             optimizer_params['relative_step'] = False
         if 'scale_parameter' not in optimizer_params:
@@ -87,8 +87,6 @@ def get_optimizer(
         if 'warmup_init' not in optimizer_params:
             optimizer_params['warmup_init'] = False
         optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
-        from toolkit.util.adafactor_stochastic_rounding import step_adafactor
-        optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
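
As a hedged usage sketch of the branch above: only the names visible in the hunks (params, learning_rate, optimizer_params, optimizer_type/lower_type) come from the diff; the import path of get_optimizer and the keyword-argument call style are assumptions.

    import torch
    from toolkit.optimizer import get_optimizer  # module path assumed

    model = torch.nn.Linear(64, 64)

    # 'adafactor' now resolves to toolkit.optimizers.adafactor.Adafactor, which
    # handles stochastic rounding itself, so no step() monkey-patch is needed.
    optimizer = get_optimizer(
        model.parameters(),
        optimizer_type='adafactor',
        learning_rate=1e-4,
        optimizer_params={},  # relative_step / scale_parameter / warmup_init default to False
    )

    loss = model(torch.randn(8, 64)).pow(2).mean()
    loss.backward()
    optimizer.step()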