Fix issue with precision on pseudo_huber loss

This commit is contained in:
Jaret Burkett
2026-04-19 10:02:27 -06:00
parent 20a99258b8
commit fc85410c9a

View File

@@ -782,7 +782,7 @@ class SDTrainer(BaseSDTrainProcess):
target = batch.latents.detach()
pred = t0
if self.train_config.loss_type == "pseudo_huber":
diff = pred - target
diff = pred.float() - target.float()
c=0.0
loss =(torch.sqrt(diff.pow(2) + c ** 2) - c)
elif self.train_config.loss_type == "mae":