Fix bug with zeroing out gradients when accumulating

Author: Jaret Burkett
Date:   2024-09-03 08:29:15 -06:00
parent 121a760c19
commit 5c8fcc8a4e


@@ -1073,7 +1073,6 @@ class SDTrainer(BaseSDTrainProcess):
         # set the weights
         network.multiplier = network_weight_list
-        self.optimizer.zero_grad(set_to_none=True)
         # activate network if it exits
@@ -1554,8 +1553,8 @@ class SDTrainer(BaseSDTrainProcess):
         else:
             batch_list = [batch]
         total_loss = None
+        self.optimizer.zero_grad()
         for batch in batch_list:
-            self.optimizer.zero_grad(set_to_none=True)
             loss = self.train_single_accumulation(batch)
             if total_loss is None:
                 total_loss = loss
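
For context, below is a minimal, self-contained sketch of the gradient-accumulation pattern this commit restores. The model, optimizer, loss_fn, and batch_list here are illustrative stand-ins, not the trainer's actual attributes; in SDTrainer the backward pass happens inside train_single_accumulation.

# Sketch of correct gradient accumulation (hypothetical names).
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Stand-ins for the micro-batches of one accumulated optimizer step.
batch_list = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]

optimizer.zero_grad()          # clear gradients ONCE, before the loop
total_loss = None
for inputs, targets in batch_list:
    # The buggy version called zero_grad() here, discarding the
    # gradients of every micro-batch except the last one.
    loss = loss_fn(model(inputs), targets) / len(batch_list)
    loss.backward()            # gradients sum across micro-batches
    total_loss = loss if total_loss is None else total_loss + loss
optimizer.step()               # one update from the accumulated gradients

Calling zero_grad() inside the loop makes each backward() start from empty .grad buffers, so the optimizer step only reflects the final micro-batch; moving the call before the loop lets .grad accumulate across all micro-batches as intended.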