mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-31 10:59:47 +00:00
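
# Wavelet-space and stepped-reconstruction loss helpers:
# - wavelet_loss compares the DWT sub-bands of the model's recovered latents
#   against those of the clean latents.
# - stepped_loss steps the prediction ahead in sigma-space and compares the
#   recovered x0 against the clean latents.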
import torch

# cached pytorch_wavelets DWTForward transform, built lazily on first use
_dwt = None


def _get_wavelet_loss(device, dtype):
    # lazily build and cache a single-level Haar DWT on the target device/dtype
    global _dwt
    if _dwt is not None:
        return _dwt

    # init wavelets
    from pytorch_wavelets import DWTForward

    # wave='db1' wave='haar'
    dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype)
    _dwt = dwt
    return dwt


def wavelet_loss(model_pred, latents, noise):
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()
    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
    with torch.no_grad():
        # DWTForward returns the low-pass band and a list of high-pass bands;
        # with J=1, xh[0] has shape [B, C, 3, H/2, W/2] holding LH, HL and HH
        model_input_xll, model_input_xh = dwt(latents)
        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(
            model_input_xh[0], dim=2
        )
        model_input = torch.cat(
            [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1
        )

    # reverse the noise to get the model prediction of the pure latents
    model_pred = noise - model_pred

    model_pred_xll, model_pred_xh = dwt(model_pred)
    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(
        model_pred_xh[0], dim=2
    )
    model_pred = torch.cat(
        [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1
    )

    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
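

# Minimal usage sketch (not part of the original repo; shapes are illustrative,
# assuming SD-style 4-channel latents, and pytorch_wavelets must be installed).
# The per-element loss covers the concatenated wavelet sub-bands.
def _example_wavelet_loss():
    bs, c, h, w = 2, 4, 64, 64
    model_pred = torch.randn(bs, c, h, w)  # stand-in for the network output
    latents = torch.randn(bs, c, h, w)
    noise = torch.randn(bs, c, h, w)
    loss = wavelet_loss(model_pred, latents, noise)  # shape [bs, 4 * c, h // 2, w // 2]
    return loss.mean()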


def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler):
    # this takes one step on a 20 step timescale from the current step (50 idx steps ahead)
    # and then reconstructs the original image at that timestep. This should lessen the error
    # possible in high noise timesteps and make the flow smoother.
    bs = model_pred.shape[0]

    noise_pred_chunks = torch.chunk(model_pred, bs)
    timestep_chunks = torch.chunk(timesteps, bs)
    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
    noise_chunks = torch.chunk(noise, bs)

    x0_pred_chunks = []

    for idx in range(bs):
        model_output = noise_pred_chunks[idx]  # predicted noise (same shape as latent)
        timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
        sample = noisy_latent_chunks[idx].to(torch.float32)
        noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)

        # Initialize scheduler step index for this sample
        scheduler._step_index = None
        scheduler._init_step_index(timestep)

        # ---- Step +50 indices (or to the end) in sigma-space ----
        sigma = scheduler.sigmas[scheduler.step_index]
        target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
        sigma_next = scheduler.sigmas[target_idx]

        # One-step Euler update along the model-predicted direction:
        # x_target = x_t + (sigma_next - sigma) * model_output
        stepped = sample + (sigma_next - sigma) * model_output

        # ---- Inverse-Gaussian recovery at the target timestep ----
        # flow-matching forward process: x_t = (1 - t) * x0 + t * noise,
        # so x0 = (x_t - t * noise) / (1 - t), with t taken as the target sigma
        t_01 = scheduler.sigmas[target_idx].to(stepped.device).to(stepped.dtype)
        original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
        x0_pred_chunks.append(original_samples)

    predicted_images = torch.cat(x0_pred_chunks, dim=0)

    return torch.nn.functional.mse_loss(
        predicted_images.float(),
        latents.float().to(device=predicted_images.device),
        reduction="none",
    )
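

# Illustrative sketch only (assumptions, not part of the original file):
# stepped_loss relies on a flow-matching scheduler that exposes `sigmas`,
# `timesteps`, `_step_index`, `_init_step_index(timestep)` and `step_index`
# (e.g. a diffusers flow-matching Euler scheduler). The stub below mimics just
# that interface so the call can be exercised end to end with dummy tensors.
class _SigmaSchedulerStub:
    def __init__(self, num_train_timesteps=1000):
        # descending sigmas in [1, 0], with timesteps = sigma * num_train_timesteps
        self.sigmas = torch.linspace(1.0, 0.0, num_train_timesteps + 1)
        self.timesteps = self.sigmas[:-1] * num_train_timesteps
        self._step_index = None

    @property
    def step_index(self):
        return self._step_index

    def _init_step_index(self, timestep):
        # index of the sigma whose timestep is closest to the given timestep
        self._step_index = int(torch.argmin(torch.abs(self.timesteps - timestep)))


def _example_stepped_loss():
    bs, c, h, w = 2, 4, 64, 64
    scheduler = _SigmaSchedulerStub()
    latents = torch.randn(bs, c, h, w)
    noise = torch.randn(bs, c, h, w)
    # sample per-item timesteps and build flow-matching noisy latents:
    # x_t = (1 - sigma) * x0 + sigma * noise
    idx = torch.randint(0, len(scheduler.timesteps), (bs,))
    timesteps = scheduler.timesteps[idx]
    sigma = scheduler.sigmas[idx].view(bs, 1, 1, 1)
    noisy_latents = (1.0 - sigma) * latents + sigma * noise
    model_pred = torch.randn(bs, c, h, w)  # stand-in for the network output
    loss = stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler)
    return loss.mean()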