Added experimental wavelet loss

This commit is contained in:
Jaret Burkett
2025-03-26 18:11:23 -06:00
parent c101f07834
commit ce4c5291a0
4 changed files with 45 additions and 2 deletions

View File

@@ -34,6 +34,7 @@ from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
def flush():
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
elif self.train_config.loss_type == "wavelet":
loss = wavelet_loss(pred, batch.latents, noise)
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")