Add t0 loss target

This commit is contained in:
Jaret Burkett
2026-03-28 13:35:21 -06:00
parent 8302b21f8f
commit 6a1fc54779
2 changed files with 18 additions and 0 deletions

View File

@@ -766,6 +766,21 @@ class SDTrainer(BaseSDTrainProcess):
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
if self.train_config.t0_loss_target:
# do the loss on a stepped timestep 0 prediction
# TODO: handle priors, preservation, masking, etc.
with torch.no_grad():
tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
# expand shape to match noise_pred
while len(tv.shape) < len(noise_pred.shape):
tv = tv.unsqueeze(-1)
# min 0.001
tv = torch.clamp(tv, min=0.001)
# step latent
t0 = noisy_latents - tv * noise_pred
target = batch.latents.detach()
pred = t0
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")

View File

@@ -497,6 +497,9 @@ class TrainConfig:
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
# compute the loss on a prediction stepped to timestep 0
self.t0_loss_target = kwargs.get('t0_loss_target', False)
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)