mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add stepped loss type
This commit is contained in:
@@ -34,7 +34,7 @@ from diffusers import EMAModel
|
||||
import math
|
||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
from toolkit.util.losses import wavelet_loss, stepped_loss
|
||||
import torch.nn.functional as F
|
||||
from toolkit.unloader import unload_text_encoder
|
||||
from PIL import Image
|
||||
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
elif self.train_config.loss_type == "wavelet":
|
||||
loss = wavelet_loss(pred, batch.latents, noise)
|
||||
elif self.train_config.loss_type == "stepped":
|
||||
loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
|
||||
# the way this loss works, it is low, increase it to match predictable LR effects
|
||||
loss = loss * 10.0
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
|
||||
@@ -630,7 +630,7 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
|
||||
stepped_chunks.append(stepped)
|
||||
|
||||
# ---- Inverse-Gaussian recovery at the target timestep ----
|
||||
t_01 = (scheduler.sigmas[target_idx] / 1000).to(stepped.device).to(stepped.dtype)
|
||||
t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
|
||||
original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
|
||||
x0_pred_chunks.append(original_samples)
|
||||
|
||||
|
||||
93
toolkit/util/losses.py
Normal file
93
toolkit/util/losses.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
|
||||
|
||||
_dwt = None
|
||||
|
||||
|
||||
def _get_wavelet_loss(device, dtype):
    """Return the shared forward wavelet transform on ``device`` / ``dtype``.

    The ``DWTForward`` module is created on first use and kept in the
    module-level ``_dwt`` global, so every later call reuses that instance.

    Args:
        device: Device the returned module must live on.
        dtype: Floating point dtype the returned module must use.

    Returns:
        The cached transform module, placed on ``device`` with ``dtype``.
    """
    global _dwt
    if _dwt is None:
        # init wavelets (imported lazily, only when the transform is first needed)
        from pytorch_wavelets import DWTForward

        # wave='db1' wave='haar'
        _dwt = DWTForward(J=1, mode="zero", wave="haar")

    # Previously a cached module was handed back untouched, so a later call
    # asking for another device/dtype silently received the stale placement.
    # ``Module.to`` updates the module in place and returns it, so re-applying
    # it on every call keeps the single cached instance in sync.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt
|
||||
|
||||
|
||||
def wavelet_loss(model_pred, latents, noise):
    """Unreduced MSE between wavelet sub-bands of recovered and true latents.

    ``noise - model_pred`` is used as the model's estimate of the clean
    latents. That estimate and ``latents`` are each run through the shared
    forward wavelet transform; the low-pass output and the three components
    unbound from dim 2 of the first high-pass level are concatenated along
    dim 1, and the two stacks are compared element by element.

    Args:
        model_pred: Model output tensor.
        latents: Clean latents the prediction should recover.
        noise: Noise tensor paired with ``model_pred``.

    Returns:
        Squared-error tensor computed with ``reduction="none"``.
    """
    pred_f32 = model_pred.float()
    latents_f32 = latents.float()
    noise_f32 = noise.float()
    transform = _get_wavelet_loss(pred_f32.device, pred_f32.dtype)

    def _stack_bands(x):
        # One forward transform, then flatten all four bands onto dim 1.
        low, high = transform(x)
        band_lh, band_hl, band_hh = torch.unbind(high[0], dim=2)
        return torch.cat([low, band_lh, band_hl, band_hh], dim=1)

    # The reference side is a fixed target, so it is built without autograd.
    with torch.no_grad():
        target_bands = _stack_bands(latents_f32)

    # reverse the noise to get the model prediction of the pure latents
    pred_bands = _stack_bands(noise_f32 - pred_f32)

    return torch.nn.functional.mse_loss(pred_bands, target_bands, reduction="none")
|
||||
|
||||
|
||||
def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler, step_size=50):
    """Unreduced MSE between latents recovered after a look-ahead step and the true latents.

    For every sample the scheduler index matching its timestep is looked up,
    the noisy latent is moved ``step_size`` sigma indices ahead (clamped to the
    last sigma) along the model-predicted direction, and the clean latent is
    then reconstructed at that target sigma using the known noise. Comparing
    that reconstruction with ``latents`` is meant to lessen the error possible
    at high-noise timesteps and make the flow smoother.

    Args:
        model_pred: Model output, batched on dim 0.
        latents: Clean latents the reconstruction is compared against.
        noise: Noise used to build ``noisy_latents``, batched on dim 0.
        noisy_latents: Noised latents fed to the model, batched on dim 0.
        timesteps: One timestep per sample, batched on dim 0.
        scheduler: Object exposing ``sigmas``, ``step_index``, ``_step_index``
            and ``_init_step_index(timestep)``.
        step_size: Number of sigma indices to step ahead. The default of 50
            matches the previously hard-coded value (a 20-step timescale if
            the scheduler holds 1000 sigmas -- TODO confirm against the
            scheduler configuration).

    Returns:
        Squared-error tensor computed with ``reduction="none"`` in float32.
    """
    batch_size = model_pred.shape[0]

    pred_chunks = torch.chunk(model_pred, batch_size)
    timestep_chunks = torch.chunk(timesteps, batch_size)
    noisy_latent_chunks = torch.chunk(noisy_latents, batch_size)
    noise_chunks = torch.chunk(noise, batch_size)

    last_idx = len(scheduler.sigmas) - 1
    x0_pred_chunks = []

    # The per-sample lookup below overwrites the scheduler's step index.
    # Remember the caller's value so the shared scheduler is left untouched.
    saved_step_index = getattr(scheduler, "_step_index", None)
    try:
        for pred_i, timestep, noisy_i, noise_i in zip(
            pred_chunks, timestep_chunks, noisy_latent_chunks, noise_chunks
        ):
            sample = noisy_i.to(torch.float32)
            # Do the arithmetic in float32 even when the model ran in a
            # lower precision, so the sigma product is not rounded early.
            direction = pred_i.to(device=sample.device, dtype=sample.dtype)
            noise_i = noise_i.to(device=sample.device, dtype=sample.dtype)

            # Initialize scheduler step index for this sample
            scheduler._step_index = None
            scheduler._init_step_index(timestep)
            current_idx = scheduler.step_index

            # ---- Step ahead (or to the end) in sigma-space ----
            target_idx = min(current_idx + step_size, last_idx)
            sigma = scheduler.sigmas[current_idx].to(
                device=sample.device, dtype=sample.dtype
            )
            sigma_next = scheduler.sigmas[target_idx].to(
                device=sample.device, dtype=sample.dtype
            )

            # One-step update along the model-predicted direction
            stepped = sample + (sigma_next - sigma) * direction

            # ---- Recover the clean latent at the target sigma ----
            x0_pred_chunks.append(
                (stepped - sigma_next * noise_i) / (1.0 - sigma_next)
            )
    finally:
        scheduler._step_index = saved_step_index

    predicted_images = torch.cat(x0_pred_chunks, dim=0)

    return torch.nn.functional.mse_loss(
        predicted_images.float(),
        latents.float().to(device=predicted_images.device),
        reduction="none",
    )
|
||||
@@ -1,39 +0,0 @@
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_dwt = None
|
||||
|
||||
|
||||
def _get_wavelet_loss(device, dtype):
    """Return the shared forward wavelet transform on ``device`` / ``dtype``.

    The ``DWTForward`` module is created on first use and kept in the
    module-level ``_dwt`` global, so every later call reuses that instance.

    Args:
        device: Device the returned module must live on.
        dtype: Floating point dtype the returned module must use.

    Returns:
        The cached transform module, placed on ``device`` with ``dtype``.
    """
    global _dwt
    if _dwt is None:
        # init wavelets (imported lazily, only when the transform is first needed)
        from pytorch_wavelets import DWTForward

        # wave='db1' wave='haar'
        _dwt = DWTForward(J=1, mode='zero', wave='haar')

    # Previously a cached module was handed back untouched, so a later call
    # asking for another device/dtype silently received the stale placement.
    # ``Module.to`` updates the module in place and returns it, so re-applying
    # it on every call keeps the single cached instance in sync.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt
|
||||
|
||||
|
||||
def wavelet_loss(model_pred, latents, noise):
    """Unreduced MSE between wavelet sub-bands of recovered and true latents.

    ``noise - model_pred`` is used as the model's estimate of the clean
    latents. That estimate and ``latents`` are each run through the shared
    forward wavelet transform; the low-pass output and the three components
    unbound from dim 2 of the first high-pass level are concatenated along
    dim 1, and the two stacks are compared element by element.

    Args:
        model_pred: Model output tensor.
        latents: Clean latents the prediction should recover.
        noise: Noise tensor paired with ``model_pred``.

    Returns:
        Squared-error tensor computed with ``reduction="none"``.
    """
    pred_f32 = model_pred.float()
    latents_f32 = latents.float()
    noise_f32 = noise.float()
    transform = _get_wavelet_loss(pred_f32.device, pred_f32.dtype)

    def _stack_bands(x):
        # One forward transform, then flatten all four bands onto dim 1.
        low, high = transform(x)
        band_lh, band_hl, band_hh = torch.unbind(high[0], dim=2)
        return torch.cat([low, band_lh, band_hl, band_hh], dim=1)

    # The reference side is a fixed target, so it is built without autograd.
    with torch.no_grad():
        target_bands = _stack_bands(latents_f32)

    # reverse the noise to get the model prediction of the pure latents
    pred_bands = _stack_bands(noise_f32 - pred_f32)

    return torch.nn.functional.mse_loss(pred_bands, target_bands, reduction="none")
|
||||
@@ -524,13 +524,15 @@ export default function SimpleJob({
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Noise Scheduler"
|
||||
label="Loss Type"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||
value={jobConfig.config.process[0].train.loss_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
|
||||
options={[
|
||||
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||
{ value: 'ddpm', label: 'DDPM' },
|
||||
{ value: 'mse', label: 'Mean Squared Error' },
|
||||
{ value: 'mae', label: 'Mean Absolute Error' },
|
||||
{ value: 'wavelet', label: 'Wavelet' },
|
||||
{ value: 'stepped', label: 'Stepped Recovery' },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
diff_output_preservation_multiplier: 1.0,
|
||||
diff_output_preservation_class: 'person',
|
||||
switch_boundary_every: 1,
|
||||
loss_type: 'mse',
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
|
||||
@@ -123,6 +123,7 @@ export interface TrainConfig {
|
||||
diff_output_preservation_multiplier: number;
|
||||
diff_output_preservation_class: string;
|
||||
switch_boundary_every: number;
|
||||
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||
}
|
||||
|
||||
export interface QuantizeKwargsConfig {
|
||||
|
||||
Reference in New Issue
Block a user