import torch

# Cached DWTForward transform (built lazily on first use).
_dwt = None


def _get_wavelet_loss(device, dtype):
    """Return the cached 1-level discrete wavelet transform used by wavelet_loss."""
    global _dwt
    if _dwt is None:
        # Lazy import so pytorch_wavelets is only required when this loss is used.
        from pytorch_wavelets import DWTForward

        # wave='db1' is an alternative; 'haar' is used here.
        _dwt = DWTForward(J=1, mode="zero", wave="haar")
    # Keep the cached transform on the caller's device/dtype.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt


def wavelet_loss(model_pred, latents, noise):
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()
    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)

    # Decompose the clean latents into LL/LH/HL/HH sub-bands (no gradient needed).
    with torch.no_grad():
        model_input_xll, model_input_xh = dwt(latents)
        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(
            model_input_xh[0], dim=2
        )
        model_input = torch.cat(
            [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1
        )

    # Reverse the noise to get the model's prediction of the clean latents.
    model_pred = noise - model_pred

    # Decompose the prediction the same way and compare sub-band to sub-band.
    model_pred_xll, model_pred_xh = dwt(model_pred)
    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(
        model_pred_xh[0], dim=2
    )
    model_pred = torch.cat(
        [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1
    )

    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")


def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler):
    # Steps each sample forward from its current timestep (50 sigma indices ahead,
    # roughly a 20-step timescale) and then reconstructs the original latents at that
    # timestep. This should lessen the error possible at high-noise timesteps and
    # make the flow smoother.
    bs = model_pred.shape[0]
    noise_pred_chunks = torch.chunk(model_pred, bs)
    timestep_chunks = torch.chunk(timesteps, bs)
    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
    noise_chunks = torch.chunk(noise, bs)
    x0_pred_chunks = []

    for idx in range(bs):
        model_output = noise_pred_chunks[idx]  # model output for this sample (same shape as the latent)
        timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
        sample = noisy_latent_chunks[idx].to(torch.float32)
        noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)

        # Initialize the scheduler's step index for this sample.
        scheduler._step_index = None
        scheduler._init_step_index(timestep)

        # ---- Step +50 indices (or to the end) in sigma-space ----
        sigma = scheduler.sigmas[scheduler.step_index]
        target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
        sigma_next = scheduler.sigmas[target_idx]

        # One Euler step along the model-predicted direction.
        stepped = sample + (sigma_next - sigma) * model_output

        # ---- Recover x0 at the target timestep by inverting the noising ----
        # x_t = (1 - sigma) * x0 + sigma * noise  =>  x0 = (x_t - sigma * noise) / (1 - sigma)
        t_01 = scheduler.sigmas[target_idx].to(stepped.device).to(stepped.dtype)
        original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)

        x0_pred_chunks.append(original_samples)

    predicted_images = torch.cat(x0_pred_chunks, dim=0)
    return torch.nn.functional.mse_loss(
        predicted_images.float(),
        latents.float().to(device=predicted_images.device),
        reduction="none",
    )
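

# --- Minimal usage sketch (illustrative only) -----------------------------------
# Shows how the two losses above might be called with stand-in tensors. It assumes
# a flow-matching setup with diffusers' FlowMatchEulerDiscreteScheduler; the tensor
# shapes, the 1000-step schedule, and the manual noising formula below are
# assumptions made for this demo, not part of the module itself.
if __name__ == "__main__":
    from diffusers import FlowMatchEulerDiscreteScheduler

    device = "cuda" if torch.cuda.is_available() else "cpu"
    bs, c, h, w = 2, 4, 32, 32

    latents = torch.randn(bs, c, h, w, device=device)
    noise = torch.randn_like(latents)
    model_pred = torch.randn_like(latents)  # stand-in for the network's output

    # Per-element wavelet-space MSE between the x0 prediction and the clean latents.
    print("wavelet_loss:", wavelet_loss(model_pred, latents, noise).mean().item())

    # stepped_loss additionally needs noisy latents, per-sample timesteps, and the scheduler.
    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(1000, device=device)
    idx = torch.randint(0, len(scheduler.timesteps) - 1, (bs,))
    timesteps = scheduler.timesteps[idx]

    # Assumed flow-match forward process: x_t = (1 - sigma) * x0 + sigma * noise
    sigmas = scheduler.sigmas.to(device)[idx].view(-1, 1, 1, 1)
    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

    s_loss = stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler)
    print("stepped_loss:", s_loss.mean().item())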