Fix issue is is else on pseudo_huber loss

This commit is contained in:
Jaret Burkett
2026-04-19 09:59:18 -06:00
parent f4445cd78c
commit 20a99258b8

View File

@@ -785,7 +785,7 @@ class SDTrainer(BaseSDTrainProcess):
diff = pred - target
c=0.0
loss =(torch.sqrt(diff.pow(2) + c ** 2) - c)
if self.train_config.loss_type == "mae":
elif self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
elif self.train_config.loss_type == "wavelet":
loss = wavelet_loss(pred, batch.latents, noise)