mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add stepped loss type
This commit is contained in:
@@ -34,7 +34,7 @@ from diffusers import EMAModel
|
|||||||
import math
|
import math
|
||||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||||
from toolkit.util.wavelet_loss import wavelet_loss
|
from toolkit.util.losses import wavelet_loss, stepped_loss
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from toolkit.unloader import unload_text_encoder
|
from toolkit.unloader import unload_text_encoder
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -679,6 +679,10 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||||
elif self.train_config.loss_type == "wavelet":
|
elif self.train_config.loss_type == "wavelet":
|
||||||
loss = wavelet_loss(pred, batch.latents, noise)
|
loss = wavelet_loss(pred, batch.latents, noise)
|
||||||
|
elif self.train_config.loss_type == "stepped":
|
||||||
|
loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler)
|
||||||
|
# the way this loss works, it is low, increase it to match predictable LR effects
|
||||||
|
loss = loss * 10.0
|
||||||
else:
|
else:
|
||||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||||
|
|
||||||
|
|||||||
@@ -630,7 +630,7 @@ class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4):
|
|||||||
stepped_chunks.append(stepped)
|
stepped_chunks.append(stepped)
|
||||||
|
|
||||||
# ---- Inverse-Gaussian recovery at the target timestep ----
|
# ---- Inverse-Gaussian recovery at the target timestep ----
|
||||||
t_01 = (scheduler.sigmas[target_idx] / 1000).to(stepped.device).to(stepped.dtype)
|
t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
|
||||||
original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
|
original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
|
||||||
x0_pred_chunks.append(original_samples)
|
x0_pred_chunks.append(original_samples)
|
||||||
|
|
||||||
|
|||||||
93
toolkit/util/losses.py
Normal file
93
toolkit/util/losses.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
_dwt = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_wavelet_loss(device, dtype):
|
||||||
|
global _dwt
|
||||||
|
if _dwt is not None:
|
||||||
|
return _dwt
|
||||||
|
|
||||||
|
# init wavelets
|
||||||
|
from pytorch_wavelets import DWTForward
|
||||||
|
|
||||||
|
# wave='db1' wave='haar'
|
||||||
|
dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype)
|
||||||
|
_dwt = dwt
|
||||||
|
return dwt
|
||||||
|
|
||||||
|
|
||||||
|
def wavelet_loss(model_pred, latents, noise):
|
||||||
|
model_pred = model_pred.float()
|
||||||
|
latents = latents.float()
|
||||||
|
noise = noise.float()
|
||||||
|
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
|
||||||
|
with torch.no_grad():
|
||||||
|
model_input_xll, model_input_xh = dwt(latents)
|
||||||
|
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(
|
||||||
|
model_input_xh[0], dim=2
|
||||||
|
)
|
||||||
|
model_input = torch.cat(
|
||||||
|
[model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# reverse the noise to get the model prediction of the pure latents
|
||||||
|
model_pred = noise - model_pred
|
||||||
|
|
||||||
|
model_pred_xll, model_pred_xh = dwt(model_pred)
|
||||||
|
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(
|
||||||
|
model_pred_xh[0], dim=2
|
||||||
|
)
|
||||||
|
model_pred = torch.cat(
|
||||||
|
[model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
|
||||||
|
|
||||||
|
|
||||||
|
def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler):
|
||||||
|
# this steps the on a 20 step timescale from the current step (50 idx steps ahead)
|
||||||
|
# and then reconstructs the original image at that timestep. This should lessen the error
|
||||||
|
# possible in high noise timesteps and make the flow smoother.
|
||||||
|
bs = model_pred.shape[0]
|
||||||
|
|
||||||
|
noise_pred_chunks = torch.chunk(model_pred, bs)
|
||||||
|
timestep_chunks = torch.chunk(timesteps, bs)
|
||||||
|
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||||
|
noise_chunks = torch.chunk(noise, bs)
|
||||||
|
|
||||||
|
x0_pred_chunks = []
|
||||||
|
|
||||||
|
for idx in range(bs):
|
||||||
|
model_output = noise_pred_chunks[idx] # predicted noise (same shape as latent)
|
||||||
|
timestep = timestep_chunks[idx] # scalar tensor per sample (e.g., [t])
|
||||||
|
sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||||
|
noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device)
|
||||||
|
|
||||||
|
# Initialize scheduler step index for this sample
|
||||||
|
scheduler._step_index = None
|
||||||
|
scheduler._init_step_index(timestep)
|
||||||
|
|
||||||
|
# ---- Step +50 indices (or to the end) in sigma-space ----
|
||||||
|
sigma = scheduler.sigmas[scheduler.step_index]
|
||||||
|
target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1)
|
||||||
|
sigma_next = scheduler.sigmas[target_idx]
|
||||||
|
|
||||||
|
# One-step update along the model-predicted direction
|
||||||
|
stepped = sample + (sigma_next - sigma) * model_output
|
||||||
|
|
||||||
|
# ---- Inverse-Gaussian recovery at the target timestep ----
|
||||||
|
t_01 = (
|
||||||
|
(scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype)
|
||||||
|
)
|
||||||
|
original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01)
|
||||||
|
x0_pred_chunks.append(original_samples)
|
||||||
|
|
||||||
|
predicted_images = torch.cat(x0_pred_chunks, dim=0)
|
||||||
|
|
||||||
|
return torch.nn.functional.mse_loss(
|
||||||
|
predicted_images.float(),
|
||||||
|
latents.float().to(device=predicted_images.device),
|
||||||
|
reduction="none",
|
||||||
|
)
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
_dwt = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_wavelet_loss(device, dtype):
|
|
||||||
global _dwt
|
|
||||||
if _dwt is not None:
|
|
||||||
return _dwt
|
|
||||||
|
|
||||||
# init wavelets
|
|
||||||
from pytorch_wavelets import DWTForward
|
|
||||||
# wave='db1' wave='haar'
|
|
||||||
dwt = DWTForward(J=1, mode='zero', wave='haar').to(
|
|
||||||
device=device, dtype=dtype)
|
|
||||||
_dwt = dwt
|
|
||||||
return dwt
|
|
||||||
|
|
||||||
|
|
||||||
def wavelet_loss(model_pred, latents, noise):
|
|
||||||
model_pred = model_pred.float()
|
|
||||||
latents = latents.float()
|
|
||||||
noise = noise.float()
|
|
||||||
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
|
|
||||||
with torch.no_grad():
|
|
||||||
model_input_xll, model_input_xh = dwt(latents)
|
|
||||||
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
|
|
||||||
model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
|
|
||||||
|
|
||||||
# reverse the noise to get the model prediction of the pure latents
|
|
||||||
model_pred = noise - model_pred
|
|
||||||
|
|
||||||
model_pred_xll, model_pred_xh = dwt(model_pred)
|
|
||||||
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
|
|
||||||
model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
|
|
||||||
|
|
||||||
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")
|
|
||||||
@@ -524,13 +524,15 @@ export default function SimpleJob({
|
|||||||
]}
|
]}
|
||||||
/>
|
/>
|
||||||
<SelectInput
|
<SelectInput
|
||||||
label="Noise Scheduler"
|
label="Loss Type"
|
||||||
className="pt-2"
|
className="pt-2"
|
||||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
value={jobConfig.config.process[0].train.loss_type}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')}
|
||||||
options={[
|
options={[
|
||||||
{ value: 'flowmatch', label: 'FlowMatch' },
|
{ value: 'mse', label: 'Mean Squared Error' },
|
||||||
{ value: 'ddpm', label: 'DDPM' },
|
{ value: 'mae', label: 'Mean Absolute Error' },
|
||||||
|
{ value: 'wavelet', label: 'Wavelet' },
|
||||||
|
{ value: 'stepped', label: 'Stepped Recovery' },
|
||||||
]}
|
]}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
diff_output_preservation_multiplier: 1.0,
|
diff_output_preservation_multiplier: 1.0,
|
||||||
diff_output_preservation_class: 'person',
|
diff_output_preservation_class: 'person',
|
||||||
switch_boundary_every: 1,
|
switch_boundary_every: 1,
|
||||||
|
loss_type: 'mse',
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
name_or_path: 'ostris/Flex.1-alpha',
|
name_or_path: 'ostris/Flex.1-alpha',
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ export interface TrainConfig {
|
|||||||
diff_output_preservation_multiplier: number;
|
diff_output_preservation_multiplier: number;
|
||||||
diff_output_preservation_class: string;
|
diff_output_preservation_class: string;
|
||||||
switch_boundary_every: number;
|
switch_boundary_every: number;
|
||||||
|
loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped';
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface QuantizeKwargsConfig {
|
export interface QuantizeKwargsConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user