mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added experimental wavelet loss
This commit is contained in:
@@ -34,6 +34,7 @@ from diffusers import EMAModel
|
||||
import math
|
||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
elif self.train_config.loss_type == "wavelet":
|
||||
loss = wavelet_loss(pred, batch.latents, noise)
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user