diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index c8861158..ebf518f2 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -785,7 +785,7 @@ class SDTrainer(BaseSDTrainProcess): diff = pred - target c=0.0 loss =(torch.sqrt(diff.pow(2) + c ** 2) - c) - if self.train_config.loss_type == "mae": + elif self.train_config.loss_type == "mae": loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") elif self.train_config.loss_type == "wavelet": loss = wavelet_loss(pred, batch.latents, noise)