diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index fa31424d..9e0d20d9 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -766,6 +766,21 @@ class SDTrainer(BaseSDTrainProcess): loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) loss = loss_per_element else: + if self.train_config.t0_loss_target: + # do the loss on a stepped timestep 0 prediction + # TODO: handle priors, preservations, masking, etc + with torch.no_grad(): + tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0 + # expand shape to match noise_pred + while len(tv.shape) < len(noise_pred.shape): + tv = tv.unsqueeze(-1) + # min 0.001 + tv = torch.clamp(tv, min=0.001) + + # step latent + t0 = noisy_latents - tv * noise_pred + target = batch.latents.detach() + pred = t0 if self.train_config.loss_type == "mae": loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a581417e..9a7427ed 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -497,6 +497,9 @@ class TrainConfig: self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow + + # do the loss on a timestep to 0 prediction + self.t0_loss_target = kwargs.get('t0_loss_target', False) # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0)