Fixed issue with grad scaling

Jaret Burkett
2024-07-20 08:21:57 -06:00
parent a2301cf28c
commit 22d2f6e28f
4 changed files with 4 additions and 3161 deletions


@@ -1520,9 +1520,9 @@ class SDTrainer(BaseSDTrainProcess):
 # I spent weeks on fighting this. DON'T DO IT
 # with fsdp_overlap_step_with_backward():
 # if self.is_bfloat:
-loss.backward()
+# loss.backward()
 # else:
-# self.scaler.scale(loss).backward()
+self.scaler.scale(loss).backward()
 # flush()
 if not self.is_grad_accumulation_step:
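
For context: the commit re-enables PyTorch's GradScaler so the loss is scaled before the backward pass, which keeps small fp16 gradients from underflowing to zero (scaling mainly matters for fp16; bf16's wider exponent range usually makes it unnecessary, which appears to be what the commented-out is_bfloat branch was gating on). Below is a minimal sketch of the standard torch.cuda.amp pattern the changed lines rely on; it is not this repository's training loop, and model, optimizer, data, and target are hypothetical placeholders.

import torch

# Hypothetical stand-ins for the trainer's model and optimizer.
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

data = torch.randn(4, 8, device="cuda")
target = torch.randn(4, 1, device="cuda")

# Run the forward pass in fp16 under autocast.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(data), target)

# Scale the loss before backward so tiny fp16 gradients survive,
# then let step() unscale the gradients and update() adjust the
# scale factor for the next iteration.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()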