Mirror of https://github.com/ostris/ai-toolkit.git
More work on mean flow loss. Moved it to an adapter. Still not functioning properly though.
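For reference, the loss being reworked here implements the MeanFlow objective of Geng et al. (2025), Alg. 1 and Eq. (6) of the paper as cited in the code comments below: the model's mean-velocity prediction u is regressed onto a target built from the instantaneous velocity v and the total time derivative of u,

    u(z_t, r, t) = v(z_t, t) - (t - r) * (d/dt) u(z_t, r, t)

with r sampled uniformly in [0, t], and with the d/dt term approximated by a finite difference rather than torch.func.jvp.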
@@ -36,7 +36,6 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 from toolkit.util.wavelet_loss import wavelet_loss
 import torch.nn.functional as F
-from toolkit.models.flux import convert_flux_to_mean_flow


 def flush():
@@ -135,9 +134,6 @@ class SDTrainer(BaseSDTrainProcess):

     def hook_before_train_loop(self):
         super().hook_before_train_loop()
-        if self.train_config.loss_type == "mean_flow":
-            # todo handle non flux models
-            convert_flux_to_mean_flow(self.sd.unet)

         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
@@ -634,102 +630,6 @@ class SDTrainer(BaseSDTrainProcess):
         return loss


-    # ------------------------------------------------------------------
-    # Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
-    # Modelling", 2025 – see Alg. 1 + Eq. (6) of the paper)
-    # This version avoids jvp / double-backprop issues with Flash Attention
-    # adapted from the work of lodestonerock
-    # ------------------------------------------------------------------
-    def get_mean_flow_loss_wip(
-        self,
-        noisy_latents: torch.Tensor,
-        conditional_embeds: PromptEmbeds,
-        match_adapter_assist: bool,
-        network_weight_list: list,
-        timesteps: torch.Tensor,
-        pred_kwargs: dict,
-        batch: 'DataLoaderBatchDTO',
-        noise: torch.Tensor,
-        unconditional_embeds: Optional[PromptEmbeds] = None,
-        **kwargs
-    ):
-        batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
-
-
-        time_end = timesteps.float() / 1000
-        # for timestep_r, we need values sampled uniformly at random between time_end and 0.0
-        time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
-
-        # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
-        # the noised data points (lerp vector) are the incoming noisy latents
-        # lerp_vector = noisy_latents
-        # compute the instantaneous vector
-        instantaneous_vector = noise - batch_latents
-
-        # finite difference method
-        epsilon_fd = 1e-3
-        jitter_std = 1e-4
-        epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
-        epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
-
-        # f(x) for the primal (we backprop through here)
-        # mean_vec_val_pred = self.forward(lerp_vector, class_label)
-        mean_vec_val_pred = self.predict_noise(
-            noisy_latents=noisy_latents,
-            timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
-            conditional_embeds=conditional_embeds,
-            unconditional_embeds=unconditional_embeds,
-            batch=batch,
-            **pred_kwargs
-        )
-
-        with torch.no_grad():
-            perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
-            # intermediate vector to compute the tangent approximation f(x + epsilon * v). NO GRAD HERE!
-            perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
-            # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
-            f_x_plus_eps_v = self.predict_noise(
-                noisy_latents=perturbed_lerp_vector,
-                timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
-                conditional_embeds=conditional_embeds,
-                unconditional_embeds=unconditional_embeds,
-                batch=batch,
-                **pred_kwargs
-            )
-
-        # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
-        mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
-        mean_vec_grad = mean_vec_grad_fd
-
-
-        # calculate the regression target for the mean vector
-        time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
-        mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
-
-        # MSE loss
-        loss = torch.nn.functional.mse_loss(
-            mean_vec_val_pred.float(),
-            mean_vec_target.float(),
-            reduction='none'
-        )
-        with torch.no_grad():
-            pure_loss = loss.mean().detach()
-        # add grad back to pure_loss so backward() can be called on it without issues
-        pure_loss.requires_grad_(True)
-        # normalize the loss of each batch element to 1.0; this method has
-        # large loss swings that can hurt the model, and this prevents that
-        with torch.no_grad():
-            loss_mean = loss.mean([1, 2, 3], keepdim=True)
-        loss = loss / loss_mean
-        loss = loss.mean()
-
-        # backward the normalized loss
-        self.accelerator.backward(loss)
-
-        # return the unnormalized loss for logging
-        return pure_loss
-
-
     # ------------------------------------------------------------------
     # Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
     # Modelling", 2025 – see Alg. 1 + Eq. (6) of the paper)
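The core of the removed method is a finite-difference stand-in for torch.func.jvp, which does not play well with Flash Attention's backward pass. Below is a minimal sketch of the same trick, assuming a generic callable f in place of self.predict_noise (all names are illustrative, not from the repo); note that, following the stop-gradient in the paper's Alg. 1, the sketch detaches the target, which the WIP code above does not do.

import torch

def fd_time_derivative(f, x_t, r, t, v, eps=1e-3):
    # forward finite difference approximating the total derivative
    # (d/dt) f(x_t, r, t) along the flow trajectory dx_t/dt = v:
    #     (f(x_t + eps * v, r, t + eps) - f(x_t, r, t)) / eps
    primal = f(x_t, r, t)  # primal pass, kept in the autograd graph
    with torch.no_grad():  # perturbed pass carries no gradient
        perturbed = f(x_t + eps * v, r, torch.clamp(t + eps, max=1.0))
    du_dt = (perturbed - primal.detach()) / eps  # stop-grad, per the paper
    return primal, du_dt

# usage sketch, mirroring the regression target above:
# u_pred, du_dt = fd_time_derivative(model, noisy_latents, r, t, v)
# target = (v - (t - r)[:, None, None, None] * du_dt).detach()
# loss = torch.nn.functional.mse_loss(u_pred.float(), target.float())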
@@ -811,7 +711,6 @@ class SDTrainer(BaseSDTrainProcess):
             base_eps,
             base_eps + jitter
         )
         # eps = (t_frac - r_frac) / 2

         # eps = 1e-3
         # primary prediction (needs grad)
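Both the removed WIP method and its replacement jitter the finite-difference step instead of using a fixed epsilon, presumably so the FD bias is not a constant offset; the clamp keeps the step safely away from zero, where the quotient would blow up. A small sketch of that pattern (names illustrative):

import torch

def jittered_eps(base_eps=1e-3, jitter_std=1e-4, device=None):
    # randomize the finite-difference step slightly, then clamp so it
    # cannot collapse toward zero and blow up (f(x + eps*v) - f(x)) / eps
    eps = base_eps + torch.randn(1, device=device) * jitter_std
    return torch.clamp(eps, min=1e-4)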