Added a basic torch profiler that can be enabled in the config during development to find some obvious issues.

This commit is contained in:
Jaret Burkett
2025-06-17 13:03:39 -06:00
parent ff617fdaea
commit 989ebfaa11
4 changed files with 53 additions and 66 deletions

View File

@@ -109,7 +109,9 @@ class Adafactor(torch.optim.Optimizer):
do_paramiter_swapping=False,
paramiter_swapping_factor=0.1,
stochastic_accumulation=True,
stochastic_rounding=True,
):
self.stochastic_rounding = stochastic_rounding
if lr is not None and relative_step:
raise ValueError(
"Cannot combine manual `lr` and `relative_step=True` options")
@@ -354,7 +356,7 @@ class Adafactor(torch.optim.Optimizer):
p_data_fp32.add_(-update)
if p.dtype != torch.float32:
if p.dtype != torch.float32 and self.stochastic_rounding:
# apply stochastic rounding
copy_stochastic(p, p_data_fp32)