mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge pull request #271 from ostris/wavelet_loss
Added experimental wavelet loss
This commit is contained in:
@@ -34,6 +34,7 @@ from diffusers import EMAModel
|
||||
import math
|
||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
elif self.train_config.loss_type == "wavelet":
|
||||
loss = wavelet_loss(pred, batch.latents, noise)
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
|
||||
@@ -34,4 +34,5 @@ huggingface_hub
|
||||
peft
|
||||
gradio
|
||||
python-slugify
|
||||
opencv-python
|
||||
opencv-python
|
||||
pytorch-wavelets==1.3.0
|
||||
@@ -405,7 +405,7 @@ class TrainConfig:
|
||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||
|
||||
self.loss_type = kwargs.get('loss_type', 'mse')
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet
|
||||
|
||||
# scale the prediction by this. Increase for more detail, decrease for less
|
||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||
|
||||
39
toolkit/util/wavelet_loss.py
Normal file
39
toolkit/util/wavelet_loss.py
Normal file
@@ -0,0 +1,39 @@
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_dwt = None
|
||||
|
||||
|
||||
def _get_wavelet_loss(device, dtype):
|
||||
global _dwt
|
||||
if _dwt is not None:
|
||||
return _dwt
|
||||
|
||||
# init wavelets
|
||||
from pytorch_wavelets import DWTForward
|
||||
# wave='db1' wave='haar'
|
||||
dwt = DWTForward(J=1, mode='zero', wave='haar').to(
|
||||
device=device, dtype=dtype)
|
||||
_dwt = dwt
|
||||
return dwt
|
||||
|
||||
|
||||
def wavelet_loss(model_pred, latents, noise):
|
||||
model_pred = model_pred.float()
|
||||
latents = latents.float()
|
||||
noise = noise.float()
|
||||
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
|
||||
with torch.no_grad():
|
||||
model_input_xll, model_input_xh = dwt(latents)
|
||||
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
|
||||
model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
|
||||
|
||||
# reverse the noise to get the model prediction of the pure latents
|
||||
model_pred = noise - model_pred
|
||||
|
||||
model_pred_xll, model_pred_xh = dwt(model_pred)
|
||||
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
|
||||
model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
|
||||
|
||||
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
|
||||
Reference in New Issue
Block a user