WIP on mean flow loss. Still a WIP.

Jaret Burkett
2025-06-12 08:00:51 -06:00
parent cf11f128b9
commit fc83eb7691
6 changed files with 465 additions and 62 deletions

View File

@@ -4,7 +4,7 @@ VERSION=dev
GIT_COMMIT=dev
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION and latest"
echo "Building version: $VERSION"
# wait 2 seconds
sleep 2

View File

@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow
def flush():
@@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.timestep_type == "mean_flow":
# TODO: handle non-Flux models
convert_flux_to_mean_flow(self.sd.transformer)
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -596,19 +600,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss + additional_loss
def get_diff_output_preservation_loss(
self,
noise_pred: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
loss_target = self.train_config.loss_target
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
@@ -643,6 +634,254 @@ class SDTrainer(BaseSDTrainProcess):
return loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modeling”, 2025; see Alg. 1 and Eq. (6) of the paper).
# This version avoids JVP / double-backprop issues with Flash-Attention.
# Adapted from the work of lodestonerock.
# ------------------------------------------------------------------
def get_mean_flow_loss_wip(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
time_end = timesteps.float() / 1000
# timestep_r needs a random value in [0, timestep_end] for each sample
time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
# time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
# Compute noised data points
# lerp_vector = noisy_latents
# compute instantaneous vector
instantaneous_vector = noise - batch_latents
# finite difference method
epsilon_fd = 1e-3
jitter_std = 1e-4
epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
# f(x + epsilon * v) for the primal (we backprop through here)
# mean_vec_val_pred = self.forward(lerp_vector, class_label)
mean_vec_val_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
with torch.no_grad():
perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
# intermediate vector for the tangent approximation f(x + epsilon * v). NO GRAD HERE!
perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
# f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
f_x_plus_eps_v = self.predict_noise(
noisy_latents=perturbed_lerp_vector,
timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
mean_vec_grad = mean_vec_grad_fd
# calculate the regression target for the mean vector
time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
# 5) MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_val_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# re-enable grad on pure_loss so the caller can call backward() on it without errors
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# this loss has large swings that can hurt the model; per-element normalization prevents that
with torch.no_grad():
loss_mean = loss.mean([1, 2, 3], keepdim=True)
loss = loss / loss_mean
loss = loss.mean()
# backward the normalized loss here; the caller only logs the returned value
self.accelerator.backward(loss)
# return the real loss for logging
return pure_loss
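For reference, this is the identity both loss implementations approximate (my reconstruction from the comments above and the cited paper; the notation u_θ, v, t, r follows the MeanFlow paper rather than the variable names in this code, which apply to both the WIP version above and get_mean_flow_loss below):

\[
u(z_t, r, t) = v(z_t, t) - (t - r)\,\frac{\mathrm{d}}{\mathrm{d}t}\,u(z_t, r, t),
\qquad
\frac{\mathrm{d}}{\mathrm{d}t}\,u = v \cdot \nabla_z u + \partial_t u .
\]

The exact total derivative (a JVP) is replaced by a forward finite difference, bumping the latent along v and the end time by a small ε, and the result is kept under stop-gradient when forming the regression target:

\[
\frac{\mathrm{d}}{\mathrm{d}t}\,u_\theta \approx
\frac{u_\theta(z_t + \varepsilon v,\, r,\, t + \varepsilon) - u_\theta(z_t, r, t)}{\varepsilon},
\qquad
u_{\text{tgt}} = v - (t - r)\,\mathrm{sg}\!\left[\frac{\mathrm{d}}{\mathrm{d}t}\,u_\theta\right].
\]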
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modeling”, 2025; see Alg. 1 and Eq. (6) of the paper).
# This version avoids JVP / double-backprop issues with Flash-Attention.
# Adapted from the work of lodestonerock.
# ------------------------------------------------------------------
def get_mean_flow_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
# ------------------------------------------------------------------
# “Slow” Mean-Flow loss: finite-difference version
# (avoids JVP / double-backprop issues with Flash-Attention)
# ------------------------------------------------------------------
dtype = get_torch_dtype(self.train_config.dtype)
total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # 1000
base_eps = 1e-3 # this is one step when multiplied by 1000
with torch.no_grad():
num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
batch_size = batch.latents.shape[0]
timestep_t_list = []
timestep_r_list = []
for i in range(batch_size):
t1 = random.randint(0, num_train_timesteps - 1)
t2 = random.randint(0, num_train_timesteps - 1)
t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
if (t_t - t_r).item() < base_eps * 1000:
# we need to ensure the time gap is wider than the epsilon (one step)
scaled_eps = base_eps * 1000
if t_t.item() + scaled_eps > 1000:
t_r = t_r - scaled_eps
else:
t_t = t_t + scaled_eps
timestep_t_list.append(t_t)
timestep_r_list.append(t_r)
eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
timesteps_t = torch.stack(timestep_t_list, dim=0).float()
timesteps_r = torch.stack(timestep_r_list, dim=0).float()
# fractions in [0,1]
t_frac = timesteps_t / total_steps
r_frac = timesteps_r / total_steps
# 2) construct data points
latents_clean = batch.latents.to(dtype)
noise_sample = noise.to(dtype)
lerp_vector = noise_sample * t_frac[:, None, None, None] \
+ latents_clean * (1.0 - t_frac[:, None, None, None])
if hasattr(self.sd, 'get_loss_target'):
instantaneous_vector = self.sd.get_loss_target(
noise=noise_sample,
batch=batch,
timesteps=timesteps,
).detach()
else:
instantaneous_vector = noise_sample - latents_clean # v_t (B,C,H,W)
# 3) finite-difference JVP approximation (bump z **and** t)
# eps_base, eps_jitter = 1e-3, 1e-4
# eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
jitter = 1e-4
eps = value_map(
torch.rand_like(t_frac),
0.0,
1.0,
base_eps,
base_eps + jitter
)
# eps = 1e-3
# primary prediction (needs grad)
mean_vec_pred = self.predict_noise(
noisy_latents=lerp_vector,
timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# secondary prediction: bump both latent and timestep by ε
with torch.no_grad():
# lerp_perturbed = lerp_vector + eps * instantaneous_vector
t_frac_plus_eps = t_frac + eps # bump time fraction
lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
+ latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
f_x_plus_eps_v = self.predict_noise(
noisy_latents=lerp_perturbed,
timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# finite-difference JVP: (f(x+εv) - f(x)) / ε
mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
# time_gap = (t_frac - r_frac)[:, None, None, None]
# mean_vec_scaler = time_gap / eps
# mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
# mean_vec_grad = mean_vec_grad.detach() # stop-grad as in Eq. 11
# 4) regression target for the mean vector
time_gap = (t_frac - r_frac)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
# mean_vec_target = instantaneous_vector - mean_vec_grad
# # 5) MSE loss
# loss = torch.nn.functional.mse_loss(
# mean_vec_pred.float(),
# mean_vec_target.float()
# )
# return loss
# 5) MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# re-enable grad on pure_loss so the caller can call backward() on it without errors
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# this loss has large swings that can hurt the model; per-element normalization prevents that
# with torch.no_grad():
# loss_mean = loss.mean([1, 2, 3], keepdim=True)
# loss = loss / loss_mean
loss = loss.mean()
# backward the loss here; the caller only logs the returned value
self.accelerator.backward(loss)
# return the real loss for logging
return pure_loss
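As an aside, here is a minimal self-contained sketch of the backward-inside-the-loss pattern both functions end with (the tensors and the plain backward() call are illustrative stand-ins; only the normalize-then-log trick is taken from the code above):

import torch
import torch.nn.functional as F

pred = torch.randn(4, 16, 32, 32, requires_grad=True)   # stands in for mean_vec_pred
target = torch.randn(4, 16, 32, 32)                     # stands in for mean_vec_target

loss = F.mse_loss(pred.float(), target.float(), reduction='none')
pure_loss = loss.mean().detach()      # unnormalized value, returned for logging only
pure_loss.requires_grad_(True)        # so a later backward() on it is a harmless no-op

# normalize each batch element to ~1.0 to damp the large per-step loss swings
loss_mean = loss.mean(dim=[1, 2, 3], keepdim=True).detach()
loss = (loss / loss_mean).mean()
loss.backward()                       # stands in for self.accelerator.backward(loss)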
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
@@ -1495,24 +1734,7 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
else:
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
@@ -1558,7 +1780,37 @@ class SDTrainer(BaseSDTrainProcess):
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
noise = next_sample_noise
timesteps = stepped_timesteps
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
elif self.train_config.loss_type == 'mean_flow':
loss = self.get_mean_flow_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
prior_pred=prior_pred,
)
else:
with self.timer('predict_unet'):
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),

View File

@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
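For orientation, these are the two switches this commit wires together, sketched as they might be set on a config (hypothetical; only the key names loss_type / timestep_type and the value 'mean_flow' appear in this diff):

from toolkit.config_modules import TrainConfig  # import path assumed

train_config = TrainConfig(
    loss_type='mean_flow',      # routes SDTrainer into get_mean_flow_loss()
    timestep_type='mean_flow',  # makes hook_before_train_loop() call convert_flux_to_mean_flow()
)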

View File

@@ -4,6 +4,7 @@ from functools import partial
from typing import Optional
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# if only one set of timesteps was passed, duplicate it so the end timesteps equal the start timesteps (zero time gap)
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)  # end timesteps default to the start timesteps
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# if only one set of timesteps was passed, duplicate it so the end timesteps equal the start timesteps (zero time gap)
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)
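Below is a shape-level sketch of the doubled-timestep convention the patched guidance embedder expects (hedged: the concrete values, the 768-dim pooled projection, and the direct call on time_text_embed are illustrative; the (2B,) layout with t in the first half and r in the second matches the torch.cat([t_frac, r_frac]) calls in SDTrainer above):

import torch

# assuming `transformer` is a loaded Flux-dev model already passed through
# convert_flux_to_mean_flow(transformer)
B = 2
t = torch.tensor([800.0, 500.0])       # end timesteps t, in [0, 1000]
r = torch.tensor([300.0, 100.0])       # start timesteps r, in [0, t]
timestep = torch.cat([t, r], dim=0)    # (2B,): first half t, second half r
guidance = torch.full((B,), 3.5)       # (B,) guidance scale
pooled = torch.randn(B, 768)           # (B, 768) pooled text projection
cond = transformer.time_text_embed(timestep, guidance, pooled)  # -> (B, D)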

View File

@@ -16,6 +16,7 @@ from diffusers import (
LCMScheduler,
FlowMatchEulerDiscreteScheduler,
)
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
@@ -159,6 +160,8 @@ def get_sampler(
scheduler_cls = LCMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
elif sampler == "mean_flow":
scheduler_cls = MeanFlowScheduler
elif sampler == "flowmatch": elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = copy.deepcopy(flux_config) config_to_use = copy.deepcopy(flux_config)

View File

@@ -0,0 +1,93 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_noise_sigma = 1.0
self.timestep_type = "linear"
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
def get_weights_for_timesteps(
self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
weights = 1.0
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype,
)
return weights
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
return noisy_model_input
def scale_model_input(
self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, **kwargs):
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
**kwargs,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
# single Euler step (Eq. 5 ⇒ x₀ = x₁ − u_θ)
output = sample - model_output
if not return_dict:
return (output,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)
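Finally, a hedged sketch of one-step sampling with the new scheduler (mean_flow_model is a hypothetical stand-in for a converted transformer that outputs the mean velocity; step() and set_train_timesteps() are exactly the methods defined above):

import torch
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler

scheduler = MeanFlowScheduler()
scheduler.set_train_timesteps(1, device='cpu')            # timesteps == [1000.]
x1 = torch.randn(1, 16, 64, 64)                           # pure noise, t_01 = 1.0
u = mean_flow_model(x1, timestep=torch.tensor([1000.0]))  # hypothetical model call
x0 = scheduler.step(u, 1000.0, x1).prev_sample            # x0 = x1 - u  (Eq. 5)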