diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ebf518f2..f86b3d4d 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -782,7 +782,7 @@ class SDTrainer(BaseSDTrainProcess): target = batch.latents.detach() pred = t0 if self.train_config.loss_type == "pseudo_huber": - diff = pred - target + diff = pred.float() - target.float() c=0.0 loss =(torch.sqrt(diff.pow(2) + c ** 2) - c) elif self.train_config.loss_type == "mae":