mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 16:30:40 +00:00
Fix issue is is else on pseudo_huber loss
This commit is contained in:
@@ -785,7 +785,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
diff = pred - target
|
||||
c=0.0
|
||||
loss =(torch.sqrt(diff.pow(2) + c ** 2) - c)
|
||||
if self.train_config.loss_type == "mae":
|
||||
elif self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
elif self.train_config.loss_type == "wavelet":
|
||||
loss = wavelet_loss(pred, batch.latents, noise)
|
||||
|
||||
Reference in New Issue
Block a user