diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index e5965dfb..769b702a 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -34,6 +34,7 @@ from diffusers import EMAModel
 import math
 from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
+from toolkit.util.wavelet_loss import wavelet_loss
 
 
 def flush():
@@ -481,6 +482,8 @@ class SDTrainer(BaseSDTrainProcess):
 
            if self.train_config.loss_type == "mae":
                loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
+           elif self.train_config.loss_type == "wavelet":
+               loss = wavelet_loss(pred, batch.latents, noise)
            else:
                loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
 
diff --git a/requirements.txt b/requirements.txt
index 12c9ce89..33e6bd2a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,4 +34,5 @@ huggingface_hub
 peft
 gradio
 python-slugify
-opencv-python
\ No newline at end of file
+opencv-python
+pytorch-wavelets==1.3.0
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3fe61027..ae591bb5 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -405,7 +405,7 @@ class TrainConfig:
         self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
         self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
 
-        self.loss_type = kwargs.get('loss_type', 'mse')
+        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet
 
         # scale the prediction by this. Increase for more detail, decrease for less
         self.pred_scaler = kwargs.get('pred_scaler', 1.0)
diff --git a/toolkit/util/wavelet_loss.py b/toolkit/util/wavelet_loss.py
new file mode 100644
index 00000000..8ae9e5a2
--- /dev/null
+++ b/toolkit/util/wavelet_loss.py
@@ -0,0 +1,39 @@
+
+import torch
+
+
+_dwt = None
+
+
+def _get_wavelet_loss(device, dtype):
+    global _dwt
+    if _dwt is not None:
+        return _dwt
+
+    # init wavelets
+    from pytorch_wavelets import DWTForward
+    # wave='db1' wave='haar'
+    dwt = DWTForward(J=1, mode='zero', wave='haar').to(
+        device=device, dtype=dtype)
+    _dwt = dwt
+    return dwt
+
+
+def wavelet_loss(model_pred, latents, noise):
+    model_pred = model_pred.float()
+    latents = latents.float()
+    noise = noise.float()
+    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
+    with torch.no_grad():
+        model_input_xll, model_input_xh = dwt(latents)
+        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
+        model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
+
+    # reverse the noise to get the model prediction of the pure latents
+    model_pred = noise - model_pred
+
+    model_pred_xll, model_pred_xh = dwt(model_pred)
+    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
+    model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
+
+    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
\ No newline at end of file
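
Not part of the patch: a minimal standalone sketch of how the new loss behaves, assuming latents shaped [batch, 4, 64, 64] and the flow-match convention used in wavelet_loss (pred ~ noise - latents, so noise - pred recovers the clean latents). The tensor shapes and the "ideal prediction" input are illustrative assumptions, not values taken from the training loop.

# Illustrative smoke test for toolkit/util/wavelet_loss.py (assumed shapes).
import torch
from toolkit.util.wavelet_loss import wavelet_loss

latents = torch.randn(2, 4, 64, 64)   # clean VAE latents (assumed shape)
noise = torch.randn_like(latents)     # sampled noise
model_pred = noise - latents          # an ideal flow-match prediction

loss = wavelet_loss(model_pred, latents, noise)

# DWTForward(J=1, wave='haar') halves H and W and returns the LL band plus the
# LH/HL/HH bands; wavelet_loss concatenates them on the channel dim, so the
# unreduced loss here is [2, 16, 32, 32] and the caller reduces it (e.g. .mean()).
print(loss.shape)   # torch.Size([2, 16, 32, 32])
print(loss.mean())  # ~0 for the ideal prediction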