mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-30 02:19:47 +00:00
39 lines
1.3 KiB
Python
39 lines
1.3 KiB
Python
|
|
import torch
|
|
|
|
|
|
_dwt = None
|
|
|
|
|
|
def _get_wavelet_loss(device, dtype):
|
|
global _dwt
|
|
if _dwt is not None:
|
|
return _dwt
|
|
|
|
# init wavelets
|
|
from pytorch_wavelets import DWTForward
|
|
# wave='db1' wave='haar'
|
|
dwt = DWTForward(J=1, mode='zero', wave='haar').to(
|
|
device=device, dtype=dtype)
|
|
_dwt = dwt
|
|
return dwt
|
|
|
|
|
|
def wavelet_loss(model_pred, latents, noise):
|
|
model_pred = model_pred.float()
|
|
latents = latents.float()
|
|
noise = noise.float()
|
|
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
|
|
with torch.no_grad():
|
|
model_input_xll, model_input_xh = dwt(latents)
|
|
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
|
|
model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
|
|
|
|
# reverse the noise to get the model prediction of the pure latents
|
|
model_pred = noise - model_pred
|
|
|
|
model_pred_xll, model_pred_xh = dwt(model_pred)
|
|
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
|
|
model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
|
|
|
|
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none") |