diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 532a1c8d..16c392d7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -34,7 +34,7 @@ from diffusers import EMAModel
 import math
 from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
-from toolkit.util.wavelet_loss import wavelet_loss
+from toolkit.util.losses import wavelet_loss, stepped_loss
 import torch.nn.functional as F
 from toolkit.unloader import unload_text_encoder
 from PIL import Image
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
             loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
         elif self.train_config.loss_type == "wavelet":
             loss = wavelet_loss(pred, batch.latents, noise)
+        elif self.train_config.loss_type == "stepped":
+            loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
+            # this loss comes out small by nature; scale it up so learning rates have their usual, predictable effect
+            loss = loss * 10.0
         else:
             loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
 
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 7774af29..65c2de3c 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -630,7 +630,7 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
             stepped_chunks.append(stepped)
 
             # ---- Inverse-Gaussian recovery at the target timestep ----
-            t_01 = (scheduler.sigmas[target_idx] / 1000).to(stepped.device).to(stepped.dtype)
+            t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
             original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
             x0_pred_chunks.append(original_samples)
 
diff --git a/toolkit/util/losses.py b/toolkit/util/losses.py
new file mode 100644
index 00000000..5328674a
--- /dev/null
+++ b/toolkit/util/losses.py
@@ -0,0 +1,93 @@
+import torch
+
+
+_dwt = None
+
+
+def _get_wavelet_loss(device, dtype):
+    global _dwt
+    if _dwt is not None:
+        return _dwt
+
+    # init wavelets
+    from pytorch_wavelets import DWTForward
+
+    # wave='db1' wave='haar'
+    dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype)
+    _dwt = dwt
+    return dwt
+
+
+def wavelet_loss(model_pred, latents, noise):
+    model_pred = model_pred.float()
+    latents = latents.float()
+    noise = noise.float()
+    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
+    with torch.no_grad():
+        model_input_xll, model_input_xh = dwt(latents)
+        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(
+            model_input_xh[0], dim=2
+        )
+        model_input = torch.cat(
+            [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1
+        )
+
+    # reverse the noise to get the model prediction of the pure latents
+    model_pred = noise - model_pred
+
+    model_pred_xll, model_pred_xh = dwt(model_pred)
+    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(
+        model_pred_xh[0], dim=2
+    )
+    model_pred = torch.cat(
+        [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1
+    )
+
+    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
+
+
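+# NOTE: stepped_loss assumes the flow-matching convention
+#     x_t = (1 - t) * x_0 + t * noise
+# with the model predicting the velocity (noise - x_0), i.e. the same relation
+# wavelet_loss inverts above; a scheduler with a different parameterization
+# would break the recovery step below.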
+def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler):
+    # this takes one step on a 20-step timescale from the current step (50 sigma
+    # indices ahead) and then reconstructs the original image at that timestep.
+    # This should lessen the error possible in high-noise timesteps and make the
+    # flow smoother.
+    bs = model_pred.shape[0]
+
+    noise_pred_chunks = torch.chunk(model_pred, bs)
+    timestep_chunks = torch.chunk(timesteps, bs)
+    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+    noise_chunks = torch.chunk(noise, bs)
+
+    x0_pred_chunks = []
+
+    for idx in range(bs):
+        model_output = noise_pred_chunks[idx]  # predicted velocity (same shape as latent)
+        timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
+        sample = noisy_latent_chunks[idx].to(torch.float32)
+        noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)
+
+        # Initialize scheduler step index for this sample
+        scheduler._step_index = None
+        scheduler._init_step_index(timestep)
+
+        # ---- Step +50 indices (or to the end) in sigma-space ----
+        sigma = scheduler.sigmas[scheduler.step_index]
+        target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
+        sigma_next = scheduler.sigmas[target_idx]
+
+        # One-step update along the model-predicted direction
+        stepped = sample + (sigma_next - sigma) * model_output
+
+        # ---- Inverse-Gaussian recovery at the target timestep ----
+        t_01 = (
+            (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
+        )
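+        # with a perfect prediction, stepped == (1 - t) * x_0 + t * noise at
+        # t = t_01, so solving for x_0 gives (stepped - t * noise) / (1 - t);
+        # this assumes sigmas are already scaled to [0, 1], which is why the
+        # /1000 was dropped in diffusion_feature_extraction.py above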
+        original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
+        x0_pred_chunks.append(original_samples)
+
+    predicted_images = torch.cat(x0_pred_chunks, dim=0)
+
+    return torch.nn.functional.mse_loss(
+        predicted_images.float(),
+        latents.float().to(device=predicted_images.device),
+        reduction="none",
+    )
diff --git a/toolkit/util/wavelet_loss.py b/toolkit/util/wavelet_loss.py
deleted file mode 100644
index 8ae9e5a2..00000000
--- a/toolkit/util/wavelet_loss.py
+++ /dev/null
@@ -1,39 +0,0 @@
-
-import torch
-
-
-_dwt = None
-
-
-def _get_wavelet_loss(device, dtype):
-    global _dwt
-    if _dwt is not None:
-        return _dwt
-
-    # init wavelets
-    from pytorch_wavelets import DWTForward
-    # wave='db1' wave='haar'
-    dwt = DWTForward(J=1, mode='zero', wave='haar').to(
-        device=device, dtype=dtype)
-    _dwt = dwt
-    return dwt
-
-
-def wavelet_loss(model_pred, latents, noise):
-    model_pred = model_pred.float()
-    latents = latents.float()
-    noise = noise.float()
-    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
-    with torch.no_grad():
-        model_input_xll, model_input_xh = dwt(latents)
-        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
-        model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
-
-    # reverse the noise to get the model prediction of the pure latents
-    model_pred = noise - model_pred
-
-    model_pred_xll, model_pred_xh = dwt(model_pred)
-    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
-    model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
-
-    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
\ No newline at end of file
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index 30b4c155..3953c182 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -524,13 +524,15 @@ export default function SimpleJob({
             ]}
           />
           <SelectInput
-            value={jobConfig.config.process[0].train.noise_scheduler}
-            onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
+            value={jobConfig.config.process[0].train.loss_type}
+            onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
             options={[
-              { value: 'flowmatch', label: 'FlowMatch' },
-              { value: 'ddpm', label: 'DDPM' },
+              { value: 'mse', label: 'Mean Squared Error' },
+              { value: 'mae', label: 'Mean Absolute Error' },
+              { value: 'wavelet', label: 'Wavelet' },
+              { value: 'stepped', label: 'Stepped Recovery' },
             ]}
           />
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index a683ac28..c18f7327 100644
--- a/ui/src/app/jobs/new/jobConfig.ts
+++ b/ui/src/app/jobs/new/jobConfig.ts
@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
       diff_output_preservation_multiplier: 1.0,
       diff_output_preservation_class: 'person',
       switch_boundary_every: 1,
+      loss_type: 'mse',
     },
     model: {
       name_or_path: 'ostris/Flex.1-alpha',
diff --git a/ui/src/types.ts b/ui/src/types.ts
index c780e7e8..e5ffb25b 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -123,6 +123,7 @@ export interface TrainConfig {
   diff_output_preservation_multiplier: number;
   diff_output_preservation_class: string;
   switch_boundary_every: number;
+  loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
 }
 
 export interface QuantizeKwargsConfig {
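
A quick sanity check of the recovery identity stepped_loss relies on; a minimal sketch, assuming the flow-match convention x_t = (1 - t) * x_0 + t * noise, a perfect velocity prediction, and stand-in sigma values in place of the scheduler.sigmas entries:

import torch

x0 = torch.randn(1, 4, 32, 32)                     # clean latents
noise = torch.randn_like(x0)
sigma, sigma_next = 0.9, 0.85                      # stand-ins for sigmas[i], sigmas[i + 50]

x_t = (1 - sigma) * x0 + sigma * noise             # noisy latents at level sigma
v = noise - x0                                     # a perfect velocity prediction
stepped = x_t + (sigma_next - sigma) * v           # the Euler step used in stepped_loss
recovered = (stepped - sigma_next * noise) / (1.0 - sigma_next)

assert torch.allclose(recovered, x0, atol=1e-5)    # exact recovery of x_0

With a perfect prediction the recovered latents match x_0 exactly, so the MSE term in stepped_loss only penalizes error carried through the step; the loss is selected by setting loss_type: 'stepped' in the train config, matching the jobConfig.ts and types.ts changes above.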