Add stepped loss type

Jaret Burkett
2025-09-22 15:50:12 -06:00
parent 28728a1e92
commit f74475161e
7 changed files with 108 additions and 46 deletions

View File

@@ -34,7 +34,7 @@ from diffusers import EMAModel
 import math
 from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
-from toolkit.util.wavelet_loss import wavelet_loss
+from toolkit.util.losses import wavelet_loss, stepped_loss
 import torch.nn.functional as F
 from toolkit.unloader import unload_text_encoder
 from PIL import Image
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
     loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
 elif self.train_config.loss_type == "wavelet":
     loss = wavelet_loss(pred, batch.latents, noise)
+elif self.train_config.loss_type == "stepped":
+    loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
+    # this loss comes out small by construction; scale it up so learning rates behave comparably to the other loss types
+    loss = loss * 10.0
 else:
     loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
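For reference, a minimal sanity check (not part of this commit) of the flow-match identities the wavelet and stepped branches both rely on, assuming a velocity-style prediction v = noise - x0 and noisy latents built as x_t = (1 - t) * x0 + t * noise:

import torch

x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
v = noise - x0  # velocity-style prediction target
t, t_next = 0.9, 0.95
x_t = (1 - t) * x0 + t * noise
x_next = (1 - t_next) * x0 + t_next * noise
# stepping along v moves exactly between noise levels (the "stepped" update)
assert torch.allclose(x_t + (t_next - t) * v, x_next, atol=1e-6)
# and x0 is recoverable from the velocity (used by the wavelet branch)
assert torch.allclose(noise - v, x0)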

View File

@@ -630,7 +630,7 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
 stepped_chunks.append(stepped)
 # ---- Inverse-Gaussian recovery at the target timestep ----
-t_01 = (scheduler.sigmas[target_idx] / 1000).to(stepped.device).to(stepped.dtype)
+t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
 original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
 x0_pred_chunks.append(original_samples)
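The dropped /1000 reflects that FlowMatch sigmas are already normalized to [0, 1]. A quick check of the recovery formula under that assumption (illustrative, not part of the commit):

import torch

x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.tensor(0.35)            # sigma at the target index, already in [0, 1]
x_t = (1.0 - t) * x0 + t * noise  # flow-match forward process
recovered = (x_t - t * noise) / (1.0 - t)
assert torch.allclose(recovered, x0, atol=1e-5)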

toolkit/util/losses.py (new file, 93 lines)
View File

@@ -0,0 +1,93 @@
import torch

_dwt = None


def _get_wavelet_loss(device, dtype):
    # lazily build and cache a single-level Haar DWT; the device/dtype of the
    # first call wins for the lifetime of the process
    global _dwt
    if _dwt is not None:
        return _dwt
    # init wavelets
    from pytorch_wavelets import DWTForward

    # wave='db1' wave='haar'
    dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype)
    _dwt = dwt
    return dwt


def wavelet_loss(model_pred, latents, noise):
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()
    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
    with torch.no_grad():
        model_input_xll, model_input_xh = dwt(latents)
        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(
            model_input_xh[0], dim=2
        )
        model_input = torch.cat(
            [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1
        )
    # reverse the predicted velocity to get the model's estimate of the pure latents
    model_pred = noise - model_pred
    model_pred_xll, model_pred_xh = dwt(model_pred)
    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(
        model_pred_xh[0], dim=2
    )
    model_pred = torch.cat(
        [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1
    )
    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler):
    # this steps the prediction forward as if on a 20-step schedule (50 sigma
    # indices ahead of the current step) and then reconstructs the original
    # image at that timestep. This should lessen the error possible at
    # high-noise timesteps and make the flow smoother.
    bs = model_pred.shape[0]
    noise_pred_chunks = torch.chunk(model_pred, bs)
    timestep_chunks = torch.chunk(timesteps, bs)
    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
    noise_chunks = torch.chunk(noise, bs)
    x0_pred_chunks = []
    for idx in range(bs):
        model_output = noise_pred_chunks[idx]  # model prediction for this sample (same shape as the latent)
        timestep = timestep_chunks[idx]  # scalar tensor per sample (e.g., [t])
        sample = noisy_latent_chunks[idx].to(torch.float32)
        noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)

        # initialize the scheduler's step index for this sample
        scheduler._step_index = None
        scheduler._init_step_index(timestep)

        # ---- Step +50 indices (or to the end) in sigma-space ----
        sigma = scheduler.sigmas[scheduler.step_index]
        target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
        sigma_next = scheduler.sigmas[target_idx]

        # one-step update along the model-predicted direction
        stepped = sample + (sigma_next - sigma) * model_output

        # ---- Inverse-Gaussian recovery at the target timestep ----
        t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
        original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
        x0_pred_chunks.append(original_samples)

    predicted_images = torch.cat(x0_pred_chunks, dim=0)
    return torch.nn.functional.mse_loss(
        predicted_images.float(),
        latents.float().to(device=predicted_images.device),
        reduction="none",
    )
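A quick shape check (illustrative, not part of the commit) for the Haar decomposition wavelet_loss builds on, assuming pytorch_wavelets' DWTForward API:

import torch
from pytorch_wavelets import DWTForward

dwt = DWTForward(J=1, mode="zero", wave="haar")
x = torch.randn(2, 4, 64, 64)
ll, highs = dwt(x)     # low-pass band and a list of stacked high-pass bands
print(ll.shape)        # torch.Size([2, 4, 32, 32])
print(highs[0].shape)  # torch.Size([2, 4, 3, 32, 32]); dim 2 holds LH, HL, HH
# wavelet_loss unbinds dim 2 and concatenates all four bands to (2, 16, 32, 32)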

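And a hypothetical smoke test for stepped_loss, assuming diffusers' FlowMatchEulerDiscreteScheduler (whose sigmas already live in [0, 1]); with a perfect velocity prediction the stepped update lands exactly on the target noise level, so the recovered latents match and the loss is roughly zero:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from toolkit.util.losses import stepped_loss

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(1000)
latents = torch.randn(2, 4, 32, 32)
noise = torch.randn_like(latents)
# pick a training timestep per sample and build the matching noisy latents
idx = torch.randint(0, len(scheduler.timesteps), (2,))
timesteps = scheduler.timesteps[idx]
sigmas = scheduler.sigmas[idx].view(-1, 1, 1, 1)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
model_pred = noise - latents  # stand-in for a perfect velocity prediction
loss = stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler)
print(loss.mean())  # ~0 for the perfect prediction; loss is per-element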
View File

toolkit/util/wavelet_loss.py (deleted, 39 lines)

@@ -1,39 +0,0 @@
import torch

_dwt = None


def _get_wavelet_loss(device, dtype):
    global _dwt
    if _dwt is not None:
        return _dwt
    # init wavelets
    from pytorch_wavelets import DWTForward
    # wave='db1' wave='haar'
    dwt = DWTForward(J=1, mode='zero', wave='haar').to(
        device=device, dtype=dtype)
    _dwt = dwt
    return dwt


def wavelet_loss(model_pred, latents, noise):
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()
    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
    with torch.no_grad():
        model_input_xll, model_input_xh = dwt(latents)
        model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
        model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
    # reverse the noise to get the model prediction of the pure latents
    model_pred = noise - model_pred
    model_pred_xll, model_pred_xh = dwt(model_pred)
    model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
    model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")

View File

@@ -524,13 +524,15 @@ export default function SimpleJob({
   ]}
 />
 <SelectInput
-  label="Noise Scheduler"
+  label="Loss Type"
   className="pt-2"
-  value={jobConfig.config.process[0].train.noise_scheduler}
-  onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
+  value={jobConfig.config.process[0].train.loss_type}
+  onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
   options={[
-    { value: 'flowmatch', label: 'FlowMatch' },
-    { value: 'ddpm', label: 'DDPM' },
+    { value: 'mse', label: 'Mean Squared Error' },
+    { value: 'mae', label: 'Mean Absolute Error' },
+    { value: 'wavelet', label: 'Wavelet' },
+    { value: 'stepped', label: 'Stepped Recovery' },
   ]}
 />
 </div>

View File

@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
   diff_output_preservation_multiplier: 1.0,
   diff_output_preservation_class: 'person',
   switch_boundary_every: 1,
+  loss_type: 'mse',
 },
 model: {
   name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -123,6 +123,7 @@ export interface TrainConfig {
   diff_output_preservation_multiplier: number;
   diff_output_preservation_class: string;
   switch_boundary_every: number;
+  loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
 }

 export interface QuantizeKwargsConfig {