mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added a basic torch profiler that can be used in config during development to find some obvious issues.
This commit is contained in:
@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
|
||||
do_paramiter_swapping=False,
|
||||
paramiter_swapping_factor=0.1,
|
||||
stochastic_accumulation=True,
|
||||
stochastic_rounding=True,
|
||||
):
|
||||
self.stochastic_rounding = stochastic_rounding
|
||||
if lr is not None and relative_step:
|
||||
raise ValueError(
|
||||
"Cannot combine manual `lr` and `relative_step=True` options")
|
||||
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.dtype != torch.float32:
|
||||
if p.dtype != torch.float32 and self.stochastic_rounding:
|
||||
# apply stochastic rounding
|
||||
copy_stochastic(p, p_data_fp32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user