mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix bug with zeroing out gradients when accumulating
This commit is contained in:
@@ -1073,7 +1073,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# activate network if it exits
|
||||
|
||||
@@ -1554,8 +1553,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
else:
|
||||
batch_list = [batch]
|
||||
total_loss = None
|
||||
self.optimizer.zero_grad()
|
||||
for batch in batch_list:
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
loss = self.train_single_accumulation(batch)
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
|
||||
Reference in New Issue
Block a user