Add stepped loss type

This commit is contained in:
Jaret Burkett
2025-09-22 15:50:12 -06:00
parent 28728a1e92
commit f74475161e
7 changed files with 108 additions and 46 deletions

View File

@@ -34,7 +34,7 @@ from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
from toolkit.util.losses import wavelet_loss, stepped_loss
import torch.nn.functional as F
from toolkit.unloader import unload_text_encoder
from PIL import Image
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
elif self.train_config.loss_type == "wavelet":
loss = wavelet_loss(pred, batch.latents, noise)
elif self.train_config.loss_type == "stepped":
loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
# the way this loss works, it is low, increase it to match predictable LR effects
loss = loss * 10.0
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")