Merge branch 'main' into dev

This commit is contained in:
Jaret Burkett
2025-06-12 08:11:19 -06:00
8 changed files with 468 additions and 65 deletions

View File

@@ -4,7 +4,7 @@ VERSION=dev
GIT_COMMIT=dev
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION and latest"
echo "Building version: $VERSION"
# wait 2 seconds
sleep 2

View File

@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow
def flush():
@@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.timestep_type == "mean_flow":
# todo handle non flux models
convert_flux_to_mean_flow(self.sd.transformer)
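# convert_flux_to_mean_flow patches the Flux timestep embedders so the transformer
# conditions on a start/end timestep pair (see the flux model changes below)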
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -595,19 +599,6 @@ class SDTrainer(BaseSDTrainProcess):
return loss + additional_loss
def get_diff_output_preservation_loss(
self,
noise_pred: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
loss_target = self.train_config.loss_target
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
@@ -641,6 +632,254 @@ class SDTrainer(BaseSDTrainProcess):
)
return loss
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modeling”, 2025; see Alg. 1 and Eq. (6) of the paper)
# This version avoids jvp / double-back-prop issues with Flash-Attention
# adapted from the work of lodestonerock
# ------------------------------------------------------------------
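# The finite-difference trick used below approximates the MeanFlow identity
#   u(z_t, r, t) = v_t - (t - r) * d/dt u(z_t, r, t)
# with a forward difference,
#   d/dt u ≈ (u(z_{t+ε}, r, t+ε) - u(z_t, r, t)) / ε,
# so no torch.func.jvp (and no double backprop through attention) is required.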
def get_mean_flow_loss_wip(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
time_end = timesteps.float() / 1000
# for timestep_r, we need a random value between 0.0 and timestep_end
time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
# time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
# Compute noised data points
# lerp_vector = noisy_latents
# compute instantaneous vector
instantaneous_vector = noise - batch_latents
# finite difference method
epsilon_fd = 1e-3
jitter_std = 1e-4
epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
# f(x + epsilon * v) for the primal (we backprop through here)
# mean_vec_val_pred = self.forward(lerp_vector, class_label)
mean_vec_val_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
with torch.no_grad():
perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
# intermediate vector to compute tangent approximation f(x + epsilon * v). NO GRAD HERE!
perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
# f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
f_x_plus_eps_v = self.predict_noise(
noisy_latents=perturbed_lerp_vector,
timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
mean_vec_grad = mean_vec_grad_fd
# calculate the regression target for the mean vector
time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
# 5) MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_val_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# re-enable grad on pure_loss so backward can be called on it without issues
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# mean-flow training has large loss swings that can hurt the model; this normalization prevents that
with torch.no_grad():
loss_mean = loss.mean([1, 2, 3], keepdim=True)
loss = loss / loss_mean
loss = loss.mean()
# backward the normalized loss here
self.accelerator.backward(loss)
# return the unnormalized loss for logging
return pure_loss
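# A self-contained sanity check (torch >= 2.0) that the forward difference used
# above matches an exact JVP on a smooth toy function:
#   f = lambda z, t: torch.sin(z) * t
#   z, t, v = torch.randn(8), torch.rand(8), torch.randn(8)
#   eps = 1e-3
#   fd = (f(z + eps * v, t + eps) - f(z, t)) / eps
#   exact = torch.func.jvp(f, (z, t), (v, torch.ones_like(t)))[1]
#   assert torch.allclose(fd, exact, atol=5e-2)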
# ------------------------------------------------------------------
# Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
# Modeling”, 2025; see Alg. 1 and Eq. (6) of the paper)
# This version avoids jvp / double-back-prop issues with Flash-Attention
# adapted from the work of lodestonerock
# ------------------------------------------------------------------
def get_mean_flow_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
# ------------------------------------------------------------------
# “Slow” Mean-Flow loss finite-difference version
# (avoids JVP / double-backprop issues with Flash-Attention)
# ------------------------------------------------------------------
dtype = get_torch_dtype(self.train_config.dtype)
total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps) # 1000
base_eps = 1e-3 # this is one step when multiplied by 1000
with torch.no_grad():
num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
batch_size = batch.latents.shape[0]
timestep_t_list = []
timestep_r_list = []
for i in range(batch_size):
t1 = random.randint(0, num_train_timesteps - 1)
t2 = random.randint(0, num_train_timesteps - 1)
t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
if (t_t - t_r).item() < base_eps * 1000:
# we need to ensure the time gap is wider than epsilon (one step)
scaled_eps = base_eps * 1000
if t_t.item() + scaled_eps > 1000:
t_r = t_r - scaled_eps
else:
t_t = t_t + scaled_eps
timestep_t_list.append(t_t)
timestep_r_list.append(t_r)
eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
timesteps_t = torch.stack(timestep_t_list, dim=0).float()
timesteps_r = torch.stack(timestep_r_list, dim=0).float()
# fractions in [0,1]
t_frac = timesteps_t / total_steps
r_frac = timesteps_r / total_steps
# 2) construct data points
latents_clean = batch.latents.to(dtype)
noise_sample = noise.to(dtype)
lerp_vector = noise_sample * t_frac[:, None, None, None] \
+ latents_clean * (1.0 - t_frac[:, None, None, None])
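# z_t = (1 - t) * x_0 + t * noise, the same convention as MeanFlowScheduler.add_noise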
if hasattr(self.sd, 'get_loss_target'):
instantaneous_vector = self.sd.get_loss_target(
noise=noise_sample,
batch=batch,
timesteps=timesteps,
).detach()
else:
instantaneous_vector = noise_sample - latents_clean # v_t (B,C,H,W)
# 3) finite-difference JVP approximation (bump z **and** t)
# eps_base, eps_jitter = 1e-3, 1e-4
# eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
jitter = 1e-4
eps = value_map(
torch.rand_like(t_frac),
0.0,
1.0,
base_eps,
base_eps + jitter
)
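# eps is drawn per element, uniformly from [base_eps, base_eps + jitter]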
# eps = 1e-3
# primary prediction (needs grad)
mean_vec_pred = self.predict_noise(
noisy_latents=lerp_vector,
timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# secondary prediction: bump both latent and timestep by ε
with torch.no_grad():
# lerp_perturbed = lerp_vector + eps * instantaneous_vector
t_frac_plus_eps = t_frac + eps # bump time fraction
lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
+ latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
f_x_plus_eps_v = self.predict_noise(
noisy_latents=lerp_perturbed,
timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
# finite-difference JVP: (f(x + ε·v) - f(x)) / ε
mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
# time_gap = (t_frac - r_frac)[:, None, None, None]
# mean_vec_scaler = time_gap / eps
# mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
# mean_vec_grad = mean_vec_grad.detach() # stop-grad as in Eq. 11
# 4) regression target for the mean vector
time_gap = (t_frac - r_frac)[:, None, None, None]
mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
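# i.e. u_target = v - (t - r) * du/dt, the mean-velocity target of Eq. (6) cited above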
# mean_vec_target = instantaneous_vector - mean_vec_grad
# # 5) MSE loss
# loss = torch.nn.functional.mse_loss(
# mean_vec_pred.float(),
# mean_vec_target.float()
# )
# return loss
# 5) MSE loss
loss = torch.nn.functional.mse_loss(
mean_vec_pred.float(),
mean_vec_target.float(),
reduction='none'
)
with torch.no_grad():
pure_loss = loss.mean().detach()
# re-enable grad on pure_loss so backward can be called on it without issues
pure_loss.requires_grad_(True)
# normalize the loss per batch element to 1.0
# mean-flow training has large loss swings that can hurt the model; this normalization prevents that
# with torch.no_grad():
# loss_mean = loss.mean([1, 2, 3], keepdim=True)
# loss = loss / loss_mean
loss = loss.mean()
# backward the training loss here
self.accelerator.backward(loss)
# return the detached loss for logging
return pure_loss
def get_prior_prediction(
@@ -1495,6 +1734,52 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
# do it for the model
noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
if self.train_config.timestep_type == 'next_sample':
with self.timer('next_sample_step'):
with torch.no_grad():
stepped_timestep_indices = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indices]
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
# do a sample at the current timestep and step it, then determine new noise
next_sample_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
stepped_latents = self.sd.step_scheduler(
next_sample_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
# stepped latents are our new noisy latents. Now we need to determine the noise in the current sample
noisy_latents = stepped_latents
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
# todo: calc next timestep; for now this may work as is
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
if len(stepped_latents.shape) == 4:
t_01 = t_01.view(-1, 1, 1, 1)
elif len(stepped_latents.shape) == 5:
t_01 = t_01.view(-1, 1, 1, 1, 1)
else:
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
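# inverts x_t = (1 - t) * x_0 + t * noise to recover the noise implied by the stepped sample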
noise = next_sample_noise
timesteps = stepped_timesteps
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
@@ -1511,54 +1796,21 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
elif self.train_config.loss_type == 'mean_flow':
loss = self.get_mean_flow_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
prior_pred=prior_pred,
)
else:
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
# do it for the model
noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
if self.train_config.timestep_type == 'next_sample':
with self.timer('next_sample_step'):
with torch.no_grad():
stepped_timestep_indices = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indices]
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
# do a sample at the current timestep and step it, then determine new noise
next_sample_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
stepped_latents = self.sd.step_scheduler(
next_sample_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
# stepped latents are our new noisy latents. Now we need to determine the noise in the current sample
noisy_latents = stepped_latents
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
# todo: calc next timestep; for now this may work as is
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
if len(stepped_latents.shape) == 4:
t_01 = t_01.view(-1, 1, 1, 1)
elif len(stepped_latents.shape) == 5:
t_01 = t_01.view(-1, 1, 1, 1, 1)
else:
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
noise = next_sample_noise
timesteps = stepped_timesteps
with self.timer('predict_unet'):
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),

View File

@@ -413,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm, mean_flow
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)

View File

@@ -111,12 +111,12 @@ class CaptionMixin:
if not hasattr(self, 'file_list'):
raise Exception('file_list not found on class instance')
img_path_or_tuple = self.file_list[index]
ext = self.dataset_config.caption_ext
if isinstance(img_path_or_tuple, tuple):
img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = None
ext = self.dataset_config.caption_ext
prompt_path = path_no_ext + ext
else:
img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
@@ -315,7 +315,7 @@ class CaptionProcessingDTOMixin:
# see if prompt file exists
path_no_ext = os.path.splitext(self.path)[0]
prompt_ext = self.dataset_config.caption_ext
prompt_path = f"{path_no_ext}.{prompt_ext}"
prompt_path = path_no_ext + prompt_ext
short_caption = None
if os.path.exists(prompt_path):

View File

@@ -4,6 +4,7 @@ from functools import partial
from typing import Optional
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# if only a start timestep is passed, duplicate it so the end timestep matches the start (zero time gap)
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# if only a start timestep is passed, duplicate it so the end timestep matches the start (zero time gap)
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
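# the start (t) and end (r) timestep embeddings are summed with the guidance embedding into one conditioning vector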
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)
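A minimal usage sketch, mirroring the call added to SDTrainer.hook_before_train_loop; the checkpoint name is only an example:
from diffusers import FluxTransformer2DModel
from toolkit.models.flux import convert_flux_to_mean_flow
transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="transformer")
# patches time_text_embed.forward in place so the model accepts a (t, r) timestep pair
convert_flux_to_mean_flow(transformer)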

View File

@@ -16,6 +16,7 @@ from diffusers import (
LCMScheduler,
FlowMatchEulerDiscreteScheduler,
)
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
@@ -159,6 +160,8 @@ def get_sampler(
scheduler_cls = LCMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
elif sampler == "mean_flow":
scheduler_cls = MeanFlowScheduler
elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = copy.deepcopy(flux_config)

View File

@@ -0,0 +1,93 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_noise_sigma = 1.0
self.timestep_type = "linear"
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
pass
def get_weights_for_timesteps(
self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
weights = 1.0
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype,
)
return weights
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
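# linear flow-match interpolation x_t = (1 - t) * x_0 + t * noise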
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
return noisy_model_input
def scale_model_input(
self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, **kwargs):
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
**kwargs: Optional[dict],
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
# single Euler step (Eq. 5 ⇒ x₀ = x₁ - uθ)
output = sample - model_output
if not return_dict:
return (output,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)
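A one-step sampling sketch for this parameterization, assuming `mean_flow_model` is a hypothetical callable that returns the learned mean velocity for a start/end timestep pair:
import torch
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
scheduler = MeanFlowScheduler()
x1 = torch.randn(1, 16, 64, 64)  # pure noise at t = 1
t = torch.tensor([1000.0])  # start timestep (t = 1.0)
r = torch.tensor([0.0])  # end timestep (r = 0.0)
u = mean_flow_model(x1, torch.cat([t, r], dim=0))  # hypothetical model call
x0 = scheduler.step(u, t, x1).prev_sample  # single Euler step: x0 = x1 - u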

View File

@@ -1 +1 @@
VERSION = "0.2.10"
VERSION = "0.3.0"