Files
ai-toolkit/toolkit/util/losses.py
2025-09-22 15:50:12 -06:00

94 lines
3.2 KiB
Python

import torch
# Lazily-built, module-wide DWT used by the wavelet loss (None until first use).
_dwt = None


def _get_wavelet_loss(device, dtype):
    """Return the shared single-level Haar ``DWTForward`` on ``device``/``dtype``.

    The transform is constructed on first use, so ``pytorch_wavelets`` is only
    imported when the wavelet loss is actually needed. It is then cached in the
    module-level ``_dwt``. Every call moves the cached module to the requested
    device and dtype. A later call with a different device/dtype therefore no
    longer receives a module whose filters live somewhere else.

    Args:
        device: Device the returned transform should run on.
        dtype: Floating dtype of the returned transform's filters.

    Returns:
        The cached ``DWTForward`` module (the same object on every call).
    """
    global _dwt
    if _dwt is None:
        # init wavelets
        from pytorch_wavelets import DWTForward

        # 'db1' and 'haar' name the same wavelet.
        _dwt = DWTForward(J=1, mode="zero", wave="haar")
    # nn.Module.to() converts the module in place and returns it; this is
    # cheap when the module is already on the requested device/dtype.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt
def wavelet_loss(model_pred, latents, noise):
    """Unreduced MSE between wavelet sub-bands of the target and predicted latents.

    All inputs are upcast to float32. ``latents`` and the clean latents implied
    by the prediction (``noise - model_pred``) are each decomposed with the DWT
    from ``_get_wavelet_loss``. The low-pass band and the three detail bands
    are concatenated on dim 1, and an element-wise MSE is returned
    (``reduction="none"``). The target decomposition runs under
    ``torch.no_grad()``; gradients flow only through the prediction.

    NOTE(review): assumes ``model_pred`` is a flow-matching velocity
    (``noise - latents``) and that inputs are 4-D ``(N, C, H, W)`` as the 2-D
    DWT expects. The result is then presumably ``(N, 4*C, H', W')``. Confirm
    against callers.
    """

    def _subbands(x):
        # Single-level DWT: keep the low-pass band and split the stacked
        # detail bands (dim 2) so every band contributes to the loss.
        low, highs = dwt(x)
        lh, hl, hh = torch.unbind(highs[0], dim=2)
        return torch.cat([low, lh, hl, hh], dim=1)

    pred = model_pred.float()
    target_latents = latents.float()
    noise_f = noise.float()
    dwt = _get_wavelet_loss(pred.device, pred.dtype)

    with torch.no_grad():
        target = _subbands(target_latents)

    # reverse the noise to get the model prediction of the pure latents
    x0_pred = noise_f - pred
    return torch.nn.functional.mse_loss(
        _subbands(x0_pred), target, reduction="none"
    )
def stepped_loss(
    model_pred,
    latents,
    noise,
    noisy_latents,
    timesteps,
    scheduler,
    step_ahead=50,
):
    """Unreduced MSE between the target latents and latents reconstructed
    after a look-ahead Euler step.

    For each sample, the prediction drives one Euler step from the sample's
    current sigma to the sigma ``step_ahead`` schedule indices later. The
    target index is clamped to the last entry of ``scheduler.sigmas``. The
    clean latents are then recovered from the stepped sample by inverting
    ``x_t = (1 - t) * x0 + t * noise`` at the target sigma, using the true
    ``noise``. The aim is to lessen the error possible at high-noise
    timesteps and to make the flow smoother.

    NOTE(review): assumes a flow-matching setup in which ``model_pred`` is the
    velocity ``noise - latents`` and ``scheduler`` behaves like diffusers'
    ``FlowMatchEulerDiscreteScheduler`` (``sigmas`` tensor, ``step_index``,
    ``_init_step_index``). Confirm against callers.

    Args:
        model_pred: Model output, batch on dim 0.
        latents: Target clean latents.
        noise: Noise that was mixed into ``noisy_latents``.
        noisy_latents: Noised latents fed to the model.
        timesteps: One timestep per sample (dim 0 equals the batch size).
        scheduler: Scheduler supplying the sigma schedule. Its private
            ``_step_index`` is overwritten while computing and restored
            before returning.
        step_ahead: Number of sigma indices to step ahead (default 50).

    Returns:
        Element-wise (``reduction="none"``) float32 MSE against ``latents``.

    Raises:
        ValueError: If ``step_ahead`` is negative.
    """
    if step_ahead < 0:
        raise ValueError(f"step_ahead must be >= 0, got {step_ahead}")

    bs = model_pred.shape[0]
    per_sample = zip(
        torch.chunk(model_pred, bs),
        torch.chunk(timesteps, bs),
        torch.chunk(noisy_latents, bs),
        torch.chunk(noise, bs),
    )
    last_idx = len(scheduler.sigmas) - 1
    # _init_step_index mutates scheduler state; remember it so computing the
    # loss leaves no lasting side effect on a possibly shared scheduler.
    saved_step_index = getattr(scheduler, "_step_index", None)

    x0_pred_chunks = []
    try:
        for pred_i, timestep, noisy_i, noise_i in per_sample:
            sample = noisy_i.to(torch.float32)
            # Do the step in float32 rather than in the prediction's dtype.
            velocity = pred_i.to(device=sample.device, dtype=sample.dtype)
            noise_i = noise_i.to(device=sample.device, dtype=sample.dtype)

            # Locate this sample's timestep in the sigma schedule.
            scheduler._step_index = None
            scheduler._init_step_index(timestep)
            cur_idx = scheduler.step_index
            target_idx = min(cur_idx + step_ahead, last_idx)
            sigma = scheduler.sigmas[cur_idx].to(
                device=sample.device, dtype=sample.dtype
            )
            sigma_next = scheduler.sigmas[target_idx].to(
                device=sample.device, dtype=sample.dtype
            )

            # One Euler step from sigma to sigma_next along the prediction.
            stepped = sample + (sigma_next - sigma) * velocity

            # Invert x_t = (1 - t) * x0 + t * noise at t = sigma_next using
            # the true noise (undefined if sigma_next == 1).
            x0_pred_chunks.append(
                (stepped - sigma_next * noise_i) / (1.0 - sigma_next)
            )
    finally:
        scheduler._step_index = saved_step_index

    predicted_images = torch.cat(x0_pred_chunks, dim=0)
    return torch.nn.functional.mse_loss(
        predicted_images.float(),
        latents.float().to(device=predicted_images.device),
        reduction="none",
    )