Bug fixes and little improvements here and there.

This commit is contained in:
Jaret Burkett
2024-06-08 06:24:20 -06:00
parent 833c833f28
commit 3f3636b788
12 changed files with 358 additions and 117 deletions

View File

@@ -81,7 +81,8 @@ def step_adafactor(self, closure=None):
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
eps = group["eps"][0] if isinstance(group["eps"], list) else group["eps"]
update = (grad ** 2) + eps
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]