WIP on the mean-flow loss; implementation is not yet complete.

This commit is contained in:
Jaret Burkett
2025-06-12 08:00:51 -06:00
parent cf11f128b9
commit fc83eb7691
6 changed files with 465 additions and 62 deletions

View File

@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow
def flush():
@@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
    """Run one-time setup immediately before the training loop starts."""
    super().hook_before_train_loop()

    train_cfg = self.train_config

    # Mean-flow training passes a (t, r) timestep pair to the model, so the
    # transformer must be converted first (see convert_flux_to_mean_flow).
    if train_cfg.timestep_type == "mean_flow":
        # TODO: handle non-flux models
        convert_flux_to_mean_flow(self.sd.transformer)

    # Prior divergence relies on a prior prediction being computed each step.
    if train_cfg.do_prior_divergence:
        self.do_prior_prediction = True
@@ -595,19 +599,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss + additional_loss
def get_diff_output_preservation_loss(
self,
noise_pred: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
loss_target = self.train_config.loss_target
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
    """Hook for subclasses to transform a batch before it is consumed.

    The base implementation is a no-op and returns the batch unchanged.
    """
    return batch
@@ -641,6 +632,254 @@ class SDTrainer(BaseSDTrainProcess):
)
return loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
# Modelling", 2025 -- see Alg. 1 + Eq. (6) of the paper).
# This version avoids jvp / double-back-prop issues with Flash-Attention
# by approximating the JVP with a finite difference.
# Adapted from the work of lodestonerock.
# ------------------------------------------------------------------
def get_mean_flow_loss_wip(
    self,
    noisy_latents: torch.Tensor,
    conditional_embeds: PromptEmbeds,
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    unconditional_embeds: Optional[PromptEmbeds] = None,
    **kwargs
) -> torch.Tensor:
    """WIP finite-difference Mean-Flow loss.

    Regresses the model's mean-velocity prediction u(z_t, t, r) onto the
    target v_t - (t - r) * du/dt, where du/dt is approximated with a
    forward finite difference instead of torch.func.jvp.

    NOTE: this method calls ``self.accelerator.backward`` itself and
    returns a *detached* scalar (``pure_loss``) for logging only — the
    caller must not backprop through the returned value.

    :param noisy_latents: latents already noised to ``timesteps``
        (assumed 4-D (B, C, H, W) — the target indexing below relies on it).
    :param timesteps: integer-scale timesteps t in [0, 1000).
    :param noise: the noise used to produce ``noisy_latents``.
    :return: detached mean of the un-normalized per-element MSE.
    """
    batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
    # t in [0, 1] fractional time for the "end" of the mean-flow interval
    time_end = timesteps.float() / 1000
    # for timestep_r, we need values from timestep_end to 0.0 randomly
    time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
    # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
    # Compute noised data points
    # lerp_vector = noisy_latents
    # compute instantaneous vector v_t = noise - x0 (rectified-flow velocity)
    instantaneous_vector = noise - batch_latents
    # finite difference method: epsilon with a small random jitter to
    # decorrelate the FD step size across optimization steps
    epsilon_fd = 1e-3
    jitter_std = 1e-4
    epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
    epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
    # f(x) for the primal (we backprop through here)
    # NOTE(review): the timestep batch is doubled by concatenating
    # (time_end, time_origin); presumably the mean-flow-converted flux
    # model splits this back into (t, r) pairs — confirm against
    # convert_flux_to_mean_flow.
    # mean_vec_val_pred = self.forward(lerp_vector, class_label)
    mean_vec_val_pred = self.predict_noise(
        noisy_latents=noisy_latents,
        timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
        conditional_embeds=conditional_embeds,
        unconditional_embeds=unconditional_embeds,
        batch=batch,
        **pred_kwargs
    )
    with torch.no_grad():
        perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
        # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE!
        perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
        # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
        f_x_plus_eps_v = self.predict_noise(
            noisy_latents=perturbed_lerp_vector,
            timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )
        # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
        # Computed under no_grad so the target is stop-gradient, matching
        # the sg(.) in the paper's Eq. (6).
        mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
        mean_vec_grad = mean_vec_grad_fd
        # calculate the regression target for the mean vector:
        # u_target = v_t - (t - r) * du/dt
        time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
        mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
    # 5) MSE loss (per-element; reduced manually below)
    loss = torch.nn.functional.mse_loss(
        mean_vec_val_pred.float(),
        mean_vec_target.float(),
        reduction='none'
    )
    with torch.no_grad():
        pure_loss = loss.mean().detach()
    # add grad to pure_loss so it can be backwards without issues
    # (it is a detached leaf; an outer .backward() on it is a no-op)
    pure_loss.requires_grad_(True)
    # normalize the loss per batch element to 1.0
    # this method has large loss swings that can hurt the model. This method will prevent that
    # NOTE(review): only the normalizer is computed under no_grad here so the
    # divided loss keeps its graph — confirm this matches the intended
    # (indentation-stripped) original.
    with torch.no_grad():
        loss_mean = loss.mean([1, 2, 3], keepdim=True)
    loss = loss / loss_mean
    loss = loss.mean()
    # backward happens here, inside the loss function
    self.accelerator.backward(loss)
    # return the real (un-normalized) loss for logging
    return pure_loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
# Modelling", 2025 -- see Alg. 1 + Eq. (6) of the paper).
# This version avoids jvp / double-back-prop issues with Flash-Attention
# by approximating the JVP with a finite difference.
# Adapted from the work of lodestonerock.
# ------------------------------------------------------------------
def get_mean_flow_loss(
    self,
    noisy_latents: torch.Tensor,
    conditional_embeds: PromptEmbeds,
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    unconditional_embeds: Optional[PromptEmbeds] = None,
    **kwargs
) -> torch.Tensor:
    """Finite-difference Mean-Flow loss.

    Samples its own (t, r) timestep pair per batch element (ignoring the
    incoming ``timesteps`` except for the optional ``get_loss_target``
    call), rebuilds the noised latents from ``batch.latents`` and
    ``noise``, and regresses the mean-velocity prediction onto
    v_t - (t - r) * du/dt with du/dt estimated by a forward difference.

    NOTE: like the WIP variant, this calls ``self.accelerator.backward``
    internally and returns a *detached* scalar for logging only.
    """
    # ------------------------------------------------------------------
    # "Slow" Mean-Flow loss finite-difference version
    # (avoids JVP / double-backprop issues with Flash-Attention)
    # ------------------------------------------------------------------
    dtype = get_torch_dtype(self.train_config.dtype)
    total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
    base_eps = 1e-3  # this is one step when multiplied by 1000
    with torch.no_grad():
        num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
        batch_size = batch.latents.shape[0]
        timestep_t_list = []
        timestep_r_list = []
        # 1) sample an ordered (t, r) pair per batch element with t >= r.
        # NOTE(review): this assumes scheduler.timesteps is descending, so
        # the smaller index (min) yields the *larger* timestep t_t — confirm.
        for i in range(batch_size):
            t1 = random.randint(0, num_train_timesteps - 1)
            t2 = random.randint(0, num_train_timesteps - 1)
            t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
            t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
            if (t_t - t_r).item() < base_eps * 1000:
                # we need to ensure the time gap is wider than the epsilon (one step)
                scaled_eps = base_eps * 1000
                if t_t.item() + scaled_eps > 1000:
                    t_r = t_r - scaled_eps
                else:
                    t_t = t_t + scaled_eps
            timestep_t_list.append(t_t)
            timestep_r_list.append(t_r)
            # NOTE(review): this per-iteration eps is dead code — it is
            # overwritten unconditionally by the value_map call below.
            eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
        timesteps_t = torch.stack(timestep_t_list, dim=0).float()
        timesteps_r = torch.stack(timestep_r_list, dim=0).float()
        # fractions in [0,1]
        t_frac = timesteps_t / total_steps
        r_frac = timesteps_r / total_steps
        # 2) construct data points: linear interpolation between noise and
        # clean latents at time t (rectified-flow forward process)
        latents_clean = batch.latents.to(dtype)
        noise_sample = noise.to(dtype)
        lerp_vector = noise_sample * t_frac[:, None, None, None] \
            + latents_clean * (1.0 - t_frac[:, None, None, None])
        # model-specific instantaneous target if the model defines one,
        # else the plain rectified-flow velocity v_t = noise - x0
        if hasattr(self.sd, 'get_loss_target'):
            instantaneous_vector = self.sd.get_loss_target(
                noise=noise_sample,
                batch=batch,
                timesteps=timesteps,
            ).detach()
        else:
            instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)
        # 3) finite-difference JVP approximation (bump z **and** t)
        # eps_base, eps_jitter = 1e-3, 1e-4
        # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
        # per-element eps; value_map presumably remaps U[0,1) linearly
        # into [base_eps, base_eps + jitter] — verify against its definition
        jitter = 1e-4
        eps = value_map(
            torch.rand_like(t_frac),
            0.0,
            1.0,
            base_eps,
            base_eps + jitter
        )
        # eps = 1e-3
    # primary prediction (needs grad)
    # NOTE(review): timestep batch is doubled by stacking (t, r); the
    # mean-flow-converted model presumably splits them — confirm.
    mean_vec_pred = self.predict_noise(
        noisy_latents=lerp_vector,
        timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
        conditional_embeds=conditional_embeds,
        unconditional_embeds=unconditional_embeds,
        batch=batch,
        **pred_kwargs
    )
    # secondary prediction: bump both latent and timestep by eps
    with torch.no_grad():
        # lerp_perturbed = lerp_vector + eps * instantaneous_vector
        t_frac_plus_eps = t_frac + eps  # bump time fraction
        # re-derive the perturbed point exactly on the interpolation path
        lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
            + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
        f_x_plus_eps_v = self.predict_noise(
            noisy_latents=lerp_perturbed,
            timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )
        # finite-difference JVP: (f(x+eps*v) - f(x)) / eps
        # NOTE(review): eps has shape (B,) while the predictions are
        # presumably 4-D — this division likely needs
        # eps[:, None, None, None] to broadcast; confirm prediction shape.
        mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
        # time_gap = (t_frac - r_frac)[:, None, None, None]
        # mean_vec_scaler = time_gap / eps
        # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
        # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11
        # 4) regression target for the mean vector:
        # u_target = v_t - (t - r) * du/dt (stop-gradient via no_grad)
        time_gap = (t_frac - r_frac)[:, None, None, None]
        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
        # mean_vec_target = instantaneous_vector - mean_vec_grad
    # # 5) MSE loss
    # loss = torch.nn.functional.mse_loss(
    #     mean_vec_pred.float(),
    #     mean_vec_target.float()
    # )
    # return loss
    # 5) MSE loss (per-element; reduced manually below)
    loss = torch.nn.functional.mse_loss(
        mean_vec_pred.float(),
        mean_vec_target.float(),
        reduction='none'
    )
    with torch.no_grad():
        pure_loss = loss.mean().detach()
    # add grad to pure_loss so it can be backwards without issues
    # (detached leaf; an outer .backward() on it is a no-op)
    pure_loss.requires_grad_(True)
    # normalize the loss per batch element to 1.0
    # this method has large loss swings that can hurt the model. This method will prevent that
    # with torch.no_grad():
    #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
    # loss = loss / loss_mean
    loss = loss.mean()
    # backward happens here, inside the loss function
    self.accelerator.backward(loss)
    # return the real loss for logging
    return pure_loss
def get_prior_prediction(
@@ -1495,6 +1734,52 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
# do it for the model
noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
if self.train_config.timestep_type == 'next_sample':
with self.timer('next_sample_step'):
with torch.no_grad():
stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
# do a sample at the current timestep and step it, then determine new noise
next_sample_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
stepped_latents = self.sd.step_scheduler(
next_sample_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
# stepped latents is our new noisy latents. Now we need to determine noise in the current sample
noisy_latents = stepped_latents
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
# todo calc next timestep, for now this may work as it
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
if len(stepped_latents.shape) == 4:
t_01 = t_01.view(-1, 1, 1, 1)
elif len(stepped_latents.shape) == 5:
t_01 = t_01.view(-1, 1, 1, 1, 1)
else:
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
noise = next_sample_noise
timesteps = stepped_timesteps
# do a prior pred if we have an unconditional image, we will swap out the giadance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
@@ -1511,54 +1796,21 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
elif self.train_config.loss_type == 'mean_flow':
loss = self.get_mean_flow_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
prior_pred=prior_pred,
)
else:
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
# do it for the model
noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
if self.train_config.timestep_type == 'next_sample':
with self.timer('next_sample_step'):
with torch.no_grad():
stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
# do a sample at the current timestep and step it, then determine new noise
next_sample_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
stepped_latents = self.sd.step_scheduler(
next_sample_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
# stepped latents is our new noisy latents. Now we need to determine noise in the current sample
noisy_latents = stepped_latents
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
# todo calc next timestep, for now this may work as it
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
if len(stepped_latents.shape) == 4:
t_01 = t_01.view(-1, 1, 1, 1)
elif len(stepped_latents.shape) == 5:
t_01 = t_01.view(-1, 1, 1, 1, 1)
else:
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
noise = next_sample_noise
timesteps = stepped_timesteps
with self.timer('predict_unet'):
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),