Experimental features and bug fixes

Jaret Burkett
2025-02-04 13:36:34 -07:00
parent e6180d1e1d
commit 216ab164ce
6 changed files with 26 additions and 16 deletions


@@ -108,6 +108,7 @@ class Adafactor(torch.optim.Optimizer):
         warmup_init=False,
         do_paramiter_swapping=False,
         paramiter_swapping_factor=0.1,
+        stochastic_accumulation=True,
     ):
         if lr is not None and relative_step:
             raise ValueError(
@@ -136,13 +137,14 @@ class Adafactor(torch.optim.Optimizer):
         self.is_stochastic_rounding_accumulation = False
         # setup stochastic grad accum hooks
-        for group in self.param_groups:
-            for param in group['params']:
-                if param.requires_grad and param.dtype != torch.float32:
-                    self.is_stochastic_rounding_accumulation = True
-                    param.register_post_accumulate_grad_hook(
-                        stochastic_grad_accummulation
-                    )
+        if stochastic_accumulation:
+            for group in self.param_groups:
+                for param in group['params']:
+                    if param.requires_grad and param.dtype != torch.float32:
+                        self.is_stochastic_rounding_accumulation = True
+                        param.register_post_accumulate_grad_hook(
+                            stochastic_grad_accummulation
+                        )
         self.do_paramiter_swapping = do_paramiter_swapping
         self.paramiter_swapping_factor = paramiter_swapping_factor
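
Note on the change above: the hooks that implement stochastic-rounding gradient accumulation are now opt-out via the new stochastic_accumulation flag (default True). The hook function stochastic_grad_accummulation itself is defined elsewhere in this file and does not appear in the diff. Below is a minimal, self-contained sketch of the general pattern, not the repository's actual code; the names copy_stochastic_, sr_accum_hook, and _accum_grad are illustrative assumptions. The idea is that parameters kept in bfloat16 accumulate gradients through a float32 intermediate and round the result back stochastically, so repeated accumulation does not systematically drop low-order bits.

import torch


def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
    # Stochastically round float32 source into bfloat16 target. Reinterpret
    # the float bits as int32, add uniform noise to the 16 low bits that
    # bfloat16 discards, then zero those bits, so the rounding direction is
    # random with probability proportional to the discarded remainder.
    bits = source.view(torch.int32) + torch.randint_like(
        source.view(torch.int32), 0, 1 << 16
    )
    bits.bitwise_and_(-65536)  # -65536 == 0xFFFF0000: keep bf16-visible bits
    target.copy_(bits.view(torch.float32))  # exact cast, low bits are zero


def sr_accum_hook(param: torch.Tensor) -> None:
    # Fires after each backward pass, once param.grad is fully accumulated.
    # Fold the fresh gradient into a shadow buffer through a float32
    # intermediate with stochastic rounding, then clear .grad so the next
    # backward pass does not re-accumulate in low precision.
    if hasattr(param, "_accum_grad"):
        fp32 = param._accum_grad.to(torch.float32)
        fp32 += param.grad.to(torch.float32)
        copy_stochastic_(param._accum_grad, fp32)
    else:
        param._accum_grad = param.grad.clone()
    param.grad = None


# Registration mirrors the gated loop in the diff: hooks go only on
# trainable non-float32 parameters, and only when the flag is enabled.
model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
stochastic_accumulation = True  # the new constructor flag, defaulting to True
if stochastic_accumulation:
    for p in model.parameters():
        if p.requires_grad and p.dtype != torch.float32:
            p.register_post_accumulate_grad_hook(sr_accum_hook)

In an optimizer built around this pattern, step() would read the shadow buffer instead of param.grad, which is presumably why the constructor records self.is_stochastic_rounding_accumulation. With this commit, passing stochastic_accumulation=False to the Adafactor constructor skips registering such hooks, leaving plain in-dtype gradient accumulation.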