mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Fix issue with precision on pseudo_huber loss
This commit is contained in:
@@ -782,7 +782,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
target = batch.latents.detach()
|
||||
pred = t0
|
||||
if self.train_config.loss_type == "pseudo_huber":
|
||||
diff = pred - target
|
||||
diff = pred.float() - target.float()
|
||||
c=0.0
|
||||
loss =(torch.sqrt(diff.pow(2) + c ** 2) - c)
|
||||
elif self.train_config.loss_type == "mae":
|
||||
|
||||
Reference in New Issue
Block a user