diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index df24e435..17bb4863 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1073,7 +1073,6 @@ class SDTrainer(BaseSDTrainProcess): # set the weights network.multiplier = network_weight_list - self.optimizer.zero_grad(set_to_none=True) # activate network if it exits @@ -1554,8 +1553,8 @@ class SDTrainer(BaseSDTrainProcess): else: batch_list = [batch] total_loss = None + self.optimizer.zero_grad() for batch in batch_list: - self.optimizer.zero_grad(set_to_none=True) loss = self.train_single_accumulation(batch) if total_loss is None: total_loss = loss