mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
Add t0 loss target
This commit is contained in:
@@ -766,6 +766,21 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
|
||||
loss = loss_per_element
|
||||
else:
|
||||
if self.train_config.t0_loss_target:
|
||||
# do the loss on a stepped timestep 0 prediction
|
||||
# TODO: handle doing priors, preservations, masking, etc
|
||||
with torch.no_grad():
|
||||
tv = timesteps.to(noise_pred.device).to(noise_pred.dtype) / 1000.0
|
||||
# expand shape to match noise_pred
|
||||
while len(tv.shape) < len(noise_pred.shape):
|
||||
tv = tv.unsqueeze(-1)
|
||||
# min 0.001
|
||||
tv = torch.clamp(tv, min=0.001)
|
||||
|
||||
# step latent
|
||||
t0 = noisy_latents - tv * noise_pred
|
||||
target = batch.latents.detach()
|
||||
pred = t0
|
||||
|
||||
if self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
@@ -497,6 +497,9 @@ class TrainConfig:
|
||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, mean_flow
|
||||
|
||||
# do the loss on a timestep to 0 prediction
|
||||
self.t0_loss_target = kwargs.get('t0_loss_target', False)
|
||||
|
||||
# scale the prediction by this. Increase for more detail, decrease for less
|
||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||
|
||||
Reference in New Issue
Block a user