Merge pull request #271 from ostris/wavelet_loss

Added experimental wavelet loss
This commit is contained in:
Jaret Burkett
2025-03-26 19:12:09 -06:00
committed by GitHub
4 changed files with 45 additions and 2 deletions

View File

@@ -34,6 +34,7 @@ from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
def flush():
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
elif self.train_config.loss_type == "wavelet":
loss = wavelet_loss(pred, batch.latents, noise)
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

View File

@@ -34,4 +34,5 @@ huggingface_hub
peft
gradio
python-slugify
opencv-python
opencv-python
pytorch-wavelets==1.3.0

View File

@@ -405,7 +405,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse')
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)

View File

@@ -0,0 +1,39 @@
import torch
_dwt = None
def _get_wavelet_loss(device, dtype):
global _dwt
if _dwt is not None:
return _dwt
# init wavelets
from pytorch_wavelets import DWTForward
# wave='db1' wave='haar'
dwt = DWTForward(J=1, mode='zero', wave='haar').to(
device=device, dtype=dtype)
_dwt = dwt
return dwt
def wavelet_loss(model_pred, latents, noise):
model_pred = model_pred.float()
latents = latents.float()
noise = noise.float()
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
with torch.no_grad():
model_input_xll, model_input_xh = dwt(latents)
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
# reverse the noise to get the model prediction of the pure latents
model_pred = noise - model_pred
model_pred_xll, model_pred_xh = dwt(model_pred)
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")