mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Merge pull request #271 from ostris/wavelet_loss
Added experimental wavelet loss
This commit is contained in:
@@ -34,6 +34,7 @@ from diffusers import EMAModel
|
|||||||
import math
|
import math
|
||||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||||
|
from toolkit.util.wavelet_loss import wavelet_loss
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.train_config.loss_type == "mae":
|
if self.train_config.loss_type == "mae":
|
||||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||||
|
elif self.train_config.loss_type == "wavelet":
|
||||||
|
loss = wavelet_loss(pred, batch.latents, noise)
|
||||||
else:
|
else:
|
||||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||||
|
|
||||||
|
|||||||
@@ -34,4 +34,5 @@ huggingface_hub
|
|||||||
peft
|
peft
|
||||||
gradio
|
gradio
|
||||||
python-slugify
|
python-slugify
|
||||||
opencv-python
|
opencv-python
|
||||||
|
pytorch-wavelets==1.3.0
|
||||||
@@ -405,7 +405,7 @@ class TrainConfig:
|
|||||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||||
|
|
||||||
self.loss_type = kwargs.get('loss_type', 'mse')
|
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet
|
||||||
|
|
||||||
# scale the prediction by this. Increase for more detail, decrease for less
|
# scale the prediction by this. Increase for more detail, decrease for less
|
||||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||||
|
|||||||
39
toolkit/util/wavelet_loss.py
Normal file
39
toolkit/util/wavelet_loss.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily-built, process-wide DWT module shared by every call.
_dwt = None


def _get_wavelet_loss(device, dtype):
    """Return the shared single-level Haar DWT module on ``device``/``dtype``.

    The transform is constructed on first use and cached in the module-level
    ``_dwt``. On every call the cached module is moved to the requested
    device and dtype, so a caller asking for a different placement than the
    first caller never receives a stale module.

    Args:
        device: Device the transform should live on.
        dtype: Floating point dtype for the transform.

    Returns:
        The cached ``DWTForward`` instance, placed on ``device``/``dtype``.
    """
    global _dwt
    if _dwt is None:
        # Imported lazily so the dependency is only needed when the wavelet
        # loss is actually used.
        from pytorch_wavelets import DWTForward

        # wave='haar' (an alternative noted by the original author: 'db1')
        _dwt = DWTForward(J=1, mode='zero', wave='haar')

    # nn.Module.to() works in place and returns the module itself; it is a
    # no-op when the module already matches the requested device and dtype.
    return _dwt.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def wavelet_loss(model_pred, latents, noise):
    """Element-wise MSE between wavelet sub-bands of prediction and latents.

    All inputs are cast to float32. The clean-latent estimate is taken as
    ``noise - model_pred``; both it and ``latents`` are decomposed with the
    shared DWT, and the resulting sub-bands are concatenated along dim 1
    before comparison.

    Args:
        model_pred: Raw model output tensor.
        latents: Clean latents used as the target.
        noise: Noise tensor that ``model_pred`` is subtracted from.

    Returns:
        Unreduced (``reduction="none"``) MSE tensor over the stacked
        sub-bands.
    """

    def _stack_subbands(x):
        # The DWT returns the low-pass band plus a list of detail tensors;
        # the first detail tensor is split into three bands along dim 2.
        low, details = dwt(x)
        band_lh, band_hl, band_hh = torch.unbind(details[0], dim=2)
        return torch.cat([low, band_lh, band_hl, band_hh], dim=1)

    pred_f32 = model_pred.float()
    clean_f32 = latents.float()
    noise_f32 = noise.float()

    dwt = _get_wavelet_loss(pred_f32.device, pred_f32.dtype)

    # NOTE(review): indentation was lost in this view; no_grad is assumed to
    # cover only the target branch so the prediction stays differentiable.
    with torch.no_grad():
        target_bands = _stack_subbands(clean_f32)

    # reverse the noise to get the model prediction of the pure latents
    pred_bands = _stack_subbands(noise_f32 - pred_f32)

    return torch.nn.functional.mse_loss(pred_bands, target_bands, reduction="none")
|
||||||
Reference in New Issue
Block a user