mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00

WIP on mean flow loss. Still a WIP.
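For context, the regression target built in this diff follows the MeanFlow identity (cf. Eq. (6) of the paper), with the total derivative approximated by a finite difference rather than an exact JVP; a sketch in the code's t/r naming:

\[
u(z_t, r, t) = v(z_t, t) - (t - r)\,\frac{\mathrm{d}}{\mathrm{d}t} u(z_t, r, t),
\qquad
\frac{\mathrm{d}}{\mathrm{d}t} u \approx \frac{f_\theta(z_t + \epsilon v,\ t + \epsilon,\ r) - f_\theta(z_t,\ t,\ r)}{\epsilon}
\]

Here v = noise - x0 is the instantaneous flow velocity and f_theta is the model's mean-velocity prediction; mean_vec_target in the code below is the right-hand side with the difference quotient computed without grad.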
@@ -4,7 +4,7 @@ VERSION=dev
 GIT_COMMIT=dev
 
 echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
-echo "Building version: $VERSION and latest"
+echo "Building version: $VERSION"
 # wait 2 seconds
 sleep 2
 
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 from toolkit.util.wavelet_loss import wavelet_loss
 import torch.nn.functional as F
+from toolkit.models.flux import convert_flux_to_mean_flow
 
 
 def flush():
@@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
 
     def hook_before_train_loop(self):
         super().hook_before_train_loop()
+        if self.train_config.timestep_type == "mean_flow":
+            # TODO: handle non-flux models
+            convert_flux_to_mean_flow(self.sd.transformer)
 
         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
@@ -596,19 +600,6 @@ class SDTrainer(BaseSDTrainProcess):
 
         return loss + additional_loss
 
-    def get_diff_output_preservation_loss(
-            self,
-            noise_pred: torch.Tensor,
-            noise: torch.Tensor,
-            noisy_latents: torch.Tensor,
-            timesteps: torch.Tensor,
-            batch: 'DataLoaderBatchDTO',
-            mask_multiplier: Union[torch.Tensor, float] = 1.0,
-            prior_pred: Union[torch.Tensor, None] = None,
-            **kwargs
-    ):
-        loss_target = self.train_config.loss_target
-
     def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
         return batch
 
@@ -643,6 +634,254 @@ class SDTrainer(BaseSDTrainProcess):
         return loss
 
 
+    # ------------------------------------------------------------------
+    # Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
+    # Modeling", 2025 – see Alg. 1 + Eq. (6) of the paper)
+    # This version avoids jvp / double-backprop issues with Flash-Attention.
+    # Adapted from the work of lodestonerock.
+    # ------------------------------------------------------------------
+    def get_mean_flow_loss_wip(
+            self,
+            noisy_latents: torch.Tensor,
+            conditional_embeds: PromptEmbeds,
+            match_adapter_assist: bool,
+            network_weight_list: list,
+            timesteps: torch.Tensor,
+            pred_kwargs: dict,
+            batch: 'DataLoaderBatchDTO',
+            noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
+            **kwargs
+    ):
+        batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
+
+        time_end = timesteps.float() / 1000
+        # for timestep r, sample uniformly between 0.0 and time_end
+        time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
+
+        # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
+        # Compute noised data points
+        # lerp_vector = noisy_latents
+        # compute the instantaneous vector v_t = noise - x0
+        instantaneous_vector = noise - batch_latents
+
+        # finite-difference method
+        epsilon_fd = 1e-3
+        jitter_std = 1e-4
+        epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
+        epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)
+
+        # f(x) for the primal (we backprop through here)
+        # mean_vec_val_pred = self.forward(lerp_vector, class_label)
+        mean_vec_val_pred = self.predict_noise(
+            noisy_latents=noisy_latents,
+            timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
+            conditional_embeds=conditional_embeds,
+            unconditional_embeds=unconditional_embeds,
+            batch=batch,
+            **pred_kwargs
+        )
+
+        with torch.no_grad():
+            perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
+            # perturbed point for the tangent approximation f(x + epsilon * v). NO GRAD HERE!
+            perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
+            # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
+            f_x_plus_eps_v = self.predict_noise(
+                noisy_latents=perturbed_lerp_vector,
+                timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
+                conditional_embeds=conditional_embeds,
+                unconditional_embeds=unconditional_embeds,
+                batch=batch,
+                **pred_kwargs
+            )
+
+            # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
+            mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
+            mean_vec_grad = mean_vec_grad_fd
+
+        # calculate the regression target for the mean vector
+        time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
+        mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad
+
+        # 5) MSE loss
+        loss = torch.nn.functional.mse_loss(
+            mean_vec_val_pred.float(),
+            mean_vec_target.float(),
+            reduction='none'
+        )
+        with torch.no_grad():
+            pure_loss = loss.mean().detach()
+        # re-enable grad on pure_loss so it can be passed through backward() without issues
+        pure_loss.requires_grad_(True)
+        # normalize the loss per batch element to 1.0; this method has large
+        # loss swings that can hurt the model, and the normalization prevents that
+        with torch.no_grad():
+            loss_mean = loss.mean([1, 2, 3], keepdim=True)
+        loss = loss / loss_mean
+        loss = loss.mean()
+
+        # backward the normalized loss here
+        self.accelerator.backward(loss)
+
+        # return the unnormalized loss for logging
+        return pure_loss
+
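For comparison, the exact tangent that this finite difference approximates could be computed with forward-mode autodiff. A minimal sketch, assuming a hypothetical `model_fn(z, t)` wrapper around the transformer call with r held fixed (`torch.func.jvp` is the standard PyTorch API; forward-mode is what the comment above says does not play well with Flash-Attention, hence the finite difference):

import torch

def exact_mean_flow_tangent(model_fn, z, t, v):
    # d/d(eps) of model_fn(z + eps * v, t + eps) at eps = 0, i.e. the total
    # derivative v . grad_z(u) + du/dt that appears in the MeanFlow identity
    u, du_dt = torch.func.jvp(
        model_fn,
        (z, t),
        (v, torch.ones_like(t)),
    )
    return u, du_dt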
+    # ------------------------------------------------------------------
+    # Mean-Flow loss (Geng et al., "Mean Flows for One-step Generative
+    # Modeling", 2025 – see Alg. 1 + Eq. (6) of the paper)
+    # This version avoids jvp / double-backprop issues with Flash-Attention.
+    # Adapted from the work of lodestonerock.
+    # ------------------------------------------------------------------
+    def get_mean_flow_loss(
+            self,
+            noisy_latents: torch.Tensor,
+            conditional_embeds: PromptEmbeds,
+            match_adapter_assist: bool,
+            network_weight_list: list,
+            timesteps: torch.Tensor,
+            pred_kwargs: dict,
+            batch: 'DataLoaderBatchDTO',
+            noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
+            **kwargs
+    ):
+        # ------------------------------------------------------------------
+        # "Slow" Mean-Flow loss – finite-difference version
+        # (avoids JVP / double-backprop issues with Flash-Attention)
+        # ------------------------------------------------------------------
+        dtype = get_torch_dtype(self.train_config.dtype)
+        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
+        base_eps = 1e-3  # this is one step when multiplied by 1000
+
+        with torch.no_grad():
+            num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
+            batch_size = batch.latents.shape[0]
+            timestep_t_list = []
+            timestep_r_list = []
+            for i in range(batch_size):
+                t1 = random.randint(0, num_train_timesteps - 1)
+                t2 = random.randint(0, num_train_timesteps - 1)
+                # scheduler timesteps are descending, so the smaller index is the later (larger) timestep
+                t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
+                t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
+                if (t_t - t_r).item() < base_eps * 1000:
+                    # we need to ensure the time gap is wider than the epsilon (one step)
+                    scaled_eps = base_eps * 1000
+                    if t_t.item() + scaled_eps > 1000:
+                        t_r = t_r - scaled_eps
+                    else:
+                        t_t = t_t + scaled_eps
+                timestep_t_list.append(t_t)
+                timestep_r_list.append(t_r)
+                eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
+            timesteps_t = torch.stack(timestep_t_list, dim=0).float()
+            timesteps_r = torch.stack(timestep_r_list, dim=0).float()
+
+            # fractions in [0,1]
+            t_frac = timesteps_t / total_steps
+            r_frac = timesteps_r / total_steps
+
+            # 2) construct data points
+            latents_clean = batch.latents.to(dtype)
+            noise_sample = noise.to(dtype)
+
+            lerp_vector = noise_sample * t_frac[:, None, None, None] \
+                + latents_clean * (1.0 - t_frac[:, None, None, None])
+
+            if hasattr(self.sd, 'get_loss_target'):
+                instantaneous_vector = self.sd.get_loss_target(
+                    noise=noise_sample,
+                    batch=batch,
+                    timesteps=timesteps,
+                ).detach()
+            else:
+                instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)
+
+            # 3) finite-difference JVP approximation (bump z **and** t)
+            # eps_base, eps_jitter = 1e-3, 1e-4
+            # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
+            jitter = 1e-4
+            eps = value_map(
+                torch.rand_like(t_frac),
+                0.0,
+                1.0,
+                base_eps,
+                base_eps + jitter
+            )
+
+        # eps = 1e-3
+        # primary prediction (needs grad)
+        mean_vec_pred = self.predict_noise(
+            noisy_latents=lerp_vector,
+            timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
+            conditional_embeds=conditional_embeds,
+            unconditional_embeds=unconditional_embeds,
+            batch=batch,
+            **pred_kwargs
+        )
+
+        # secondary prediction: bump both the latent and the timestep by ε
+        with torch.no_grad():
+            # lerp_perturbed = lerp_vector + eps * instantaneous_vector
+            t_frac_plus_eps = t_frac + eps  # bump the time fraction
+            lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
+                + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
+
+            f_x_plus_eps_v = self.predict_noise(
+                noisy_latents=lerp_perturbed,
+                timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
+                conditional_embeds=conditional_embeds,
+                unconditional_embeds=unconditional_embeds,
+                batch=batch,
+                **pred_kwargs
+            )
+
+        # finite-difference JVP: (f(x+εv) – f(x)) / ε
+        mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
+        # time_gap = (t_frac - r_frac)[:, None, None, None]
+        # mean_vec_scaler = time_gap / eps
+        # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
+        # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11
+
+        # 4) regression target for the mean vector
+        time_gap = (t_frac - r_frac)[:, None, None, None]
+        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
+        # mean_vec_target = instantaneous_vector - mean_vec_grad
+
+        # # 5) MSE loss
+        # loss = torch.nn.functional.mse_loss(
+        #     mean_vec_pred.float(),
+        #     mean_vec_target.float()
+        # )
+        # return loss
+        # 5) MSE loss
+        loss = torch.nn.functional.mse_loss(
+            mean_vec_pred.float(),
+            mean_vec_target.float(),
+            reduction='none'
+        )
+        with torch.no_grad():
+            pure_loss = loss.mean().detach()
+        # re-enable grad on pure_loss so it can be passed through backward() without issues
+        pure_loss.requires_grad_(True)
+        # normalize the loss per batch element to 1.0; this method has large
+        # loss swings that can hurt the model, and the normalization prevents that
+        # with torch.no_grad():
+        #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
+        # loss = loss / loss_mean
+        loss = loss.mean()
+
+        # backward the normalized loss here
+        self.accelerator.backward(loss)
+
+        # return the unnormalized loss for logging
+        return pure_loss
+
+
     def get_prior_prediction(
             self,
             noisy_latents: torch.Tensor,
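The per-sample Python loop that draws each (t, r) pair above can be read as the following batched sketch (same conventions: descending scheduler timesteps, pairs whose gap is under one step get widened; this is a sketch only, the loop above is authoritative):

import torch

num_train_timesteps, batch_size, base_eps = 1000, 4, 1e-3
ts = torch.linspace(1000, 1, num_train_timesteps)  # descending, as in the scheduler
idx1 = torch.randint(0, num_train_timesteps, (batch_size,))
idx2 = torch.randint(0, num_train_timesteps, (batch_size,))
t_t = ts[torch.minimum(idx1, idx2)]  # smaller index -> later (larger) timestep t
t_r = ts[torch.maximum(idx1, idx2)]  # larger index -> earlier (smaller) timestep r
# widen pairs whose gap is narrower than one step (base_eps * 1000)
too_close = (t_t - t_r) < base_eps * 1000
overflow = t_t + base_eps * 1000 > 1000
t_r = torch.where(too_close & overflow, t_r - base_eps * 1000, t_r)
t_t = torch.where(too_close & ~overflow, t_t + base_eps * 1000, t_t)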
@@ -1495,24 +1734,7 @@ class SDTrainer(BaseSDTrainProcess):
                 pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
 
         self.before_unet_predict()
-        # do a prior pred if we have an unconditional image, we will swap out the guidance later
-        if batch.unconditional_latents is not None or self.do_guided_loss:
-            # do guided loss
-            loss = self.get_guided_loss(
-                noisy_latents=noisy_latents,
-                conditional_embeds=conditional_embeds,
-                match_adapter_assist=match_adapter_assist,
-                network_weight_list=network_weight_list,
-                timesteps=timesteps,
-                pred_kwargs=pred_kwargs,
-                batch=batch,
-                noise=noise,
-                unconditional_embeds=unconditional_embeds,
-                mask_multiplier=mask_multiplier,
-                prior_pred=prior_pred,
-            )
-
-        else:
         if unconditional_embeds is not None:
             unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
         with self.timer('condition_noisy_latents'):

@@ -1558,7 +1780,37 @@ class SDTrainer(BaseSDTrainProcess):
                 next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
                 noise = next_sample_noise
                 timesteps = stepped_timesteps
+        # do a prior pred if we have an unconditional image, we will swap out the guidance later
+        if batch.unconditional_latents is not None or self.do_guided_loss:
+            # do guided loss
+            loss = self.get_guided_loss(
+                noisy_latents=noisy_latents,
+                conditional_embeds=conditional_embeds,
+                match_adapter_assist=match_adapter_assist,
+                network_weight_list=network_weight_list,
+                timesteps=timesteps,
+                pred_kwargs=pred_kwargs,
+                batch=batch,
+                noise=noise,
+                unconditional_embeds=unconditional_embeds,
+                mask_multiplier=mask_multiplier,
+                prior_pred=prior_pred,
+            )
+
+        elif self.train_config.loss_type == 'mean_flow':
+            loss = self.get_mean_flow_loss(
+                noisy_latents=noisy_latents,
+                conditional_embeds=conditional_embeds,
+                match_adapter_assist=match_adapter_assist,
+                network_weight_list=network_weight_list,
+                timesteps=timesteps,
+                pred_kwargs=pred_kwargs,
+                batch=batch,
+                noise=noise,
+                unconditional_embeds=unconditional_embeds,
+                prior_pred=prior_pred,
+            )
+        else:
             with self.timer('predict_unet'):
                 noise_pred = self.predict_noise(
                     noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
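Unlike the other branches, get_mean_flow_loss calls accelerator.backward itself and hands back a detached, grad-enabled scalar for logging. A self-contained toy of that normalize-then-backward pattern (the per-sample normalization is active in the _wip variant and commented out in the final one; the tensors here are made up):

import torch
import torch.nn.functional as F

pred = torch.randn(4, 3, 8, 8, requires_grad=True)
target = torch.randn(4, 3, 8, 8)

loss = F.mse_loss(pred, target, reduction='none')   # per-element loss
pure_loss = loss.mean().detach()                    # raw value, kept for logging
with torch.no_grad():
    per_sample = loss.mean([1, 2, 3], keepdim=True)
(loss / per_sample).mean().backward()               # each sample contributes ~1.0
print(pure_loss.item(), pred.grad.abs().mean().item())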
@@ -413,7 +413,7 @@ class TrainConfig:
         self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
         self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
 
-        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm
+        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm, mean_flow
 
         # scale the prediction by this. Increase for more detail, decrease for less
         self.pred_scaler = kwargs.get('pred_scaler', 1.0)
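The new 'mean_flow' value is read straight from kwargs, so selecting it presumably looks like the other loss types. A hypothetical sketch (the module path and constructor usage are assumptions, not confirmed by this diff):

from toolkit.config_modules import TrainConfig  # assumed location of TrainConfig

train_config = TrainConfig(loss_type='mean_flow', timestep_type='mean_flow')
assert train_config.loss_type == 'mean_flow'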
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Optional
 import torch
 from diffusers import FluxTransformer2DModel
+from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
 
 
 def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
 
     transformer._pre_gpu_split_to = transformer.to
     transformer.to = partial(new_device_to, transformer)
+
+
+def mean_flow_time_text_embed_forward(self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
+    # if no ending timestep was passed, duplicate the start timestep as the ending one
+    if timestep.shape[0] == pooled_projection.shape[0] // 2:
+        timestep = torch.cat([timestep, timestep], dim=0)  # end timestep == start timestep
+
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (2N, D)
+
+    # fold the doubled (t, r) batch back to N, mirroring the guidance variant below
+    timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
+    timesteps_emb = timesteps_emb_start + timesteps_emb_end
+
+    pooled_projections = self.text_embedder(pooled_projection)
+
+    conditioning = timesteps_emb + pooled_projections
+
+    return conditioning
+
+
+def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
+    # if no ending timestep was passed, duplicate the start timestep as the ending one
+    if timestep.shape[0] == pooled_projection.shape[0] // 2:
+        timestep = torch.cat([timestep, timestep], dim=0)
+    timesteps_proj = self.time_proj(timestep)
+    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (2N, D)
+
+    guidance_proj = self.time_proj(guidance)
+    guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+    timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
+
+    time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
+
+    pooled_projections = self.text_embedder(pooled_projection)
+    conditioning = time_guidance_emb + pooled_projections
+
+    return conditioning
+
+
+def convert_flux_to_mean_flow(
+    transformer: FluxTransformer2DModel,
+):
+    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
+        transformer.time_text_embed.forward = partial(
+            mean_flow_time_text_embed_forward, transformer.time_text_embed
+        )
+    elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
+        transformer.time_text_embed.forward = partial(
+            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
+        )
+    else:
+        raise ValueError(
+            "Unsupported time_text_embed type: {}".format(
+                type(transformer.time_text_embed)
+            )
+        )
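The patched forwards consume the start and end timesteps concatenated along the batch dimension, matching the trainer's torch.cat([t_frac, r_frac], dim=0) * 1000 call. A small, self-contained shape sketch (dimensions are illustrative, not Flux's real ones):

import torch

B, D = 2, 64                                # illustrative batch and embed dims
t = torch.rand(B)                           # start timesteps t
r = torch.rand(B) * t                       # end timesteps r <= t
timestep = torch.cat([t, r], dim=0) * 1000  # (2B,), as passed by the trainer
emb = torch.randn(2 * B, D)                 # stand-in for timestep_embedder output
emb_t, emb_r = emb.chunk(2, dim=0)          # recover per-(t, r) embeddings, (B, D) each
print(emb_t.shape, emb_r.shape)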
@@ -16,6 +16,7 @@ from diffusers import (
     LCMScheduler,
     FlowMatchEulerDiscreteScheduler,
 )
+from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler
 
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 

@@ -159,6 +160,8 @@ def get_sampler(
         scheduler_cls = LCMScheduler
     elif sampler == "custom_lcm":
         scheduler_cls = CustomLCMScheduler
+    elif sampler == "mean_flow":
+        scheduler_cls = MeanFlowScheduler
     elif sampler == "flowmatch":
         scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
         config_to_use = copy.deepcopy(flux_config)
toolkit/samplers/mean_flow_scheduler.py (new file, 93 lines)
@@ -0,0 +1,93 @@
+from typing import Union
+from diffusers import FlowMatchEulerDiscreteScheduler
+import torch
+from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next
+            model input in the denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
+class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.init_noise_sigma = 1.0
+        self.timestep_type = "linear"
+
+        with torch.no_grad():
+            # create weights for timesteps
+            num_timesteps = 1000
+
+            # create linear timesteps from 1000 to 1
+            timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")
+
+            self.linear_timesteps = timesteps
+
+    def get_weights_for_timesteps(
+        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
+    ) -> torch.Tensor:
+        # get the indices of the timesteps
+        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
+
+        weights = 1.0
+
+        # get the weights for the timesteps
+        if timestep_type == "weighted":
+            weights = torch.tensor(
+                [default_weighing_scheme[i] for i in step_indices],
+                device=timesteps.device,
+                dtype=timesteps.dtype,
+            )
+
+        return weights
+
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
+        t_01 = (timesteps / 1000).to(original_samples.device)
+        # broadcast (B,) timesteps over (B, C, H, W) samples
+        while t_01.dim() < original_samples.dim():
+            t_01 = t_01.unsqueeze(-1)
+        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
+        return noisy_model_input
+
+    def scale_model_input(
+        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
+    ) -> torch.Tensor:
+        return sample
+
+    def set_train_timesteps(self, num_timesteps, device, **kwargs):
+        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
+        self.timesteps = timesteps
+        return timesteps
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+
+        # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
+        output = sample - model_output
+        if not return_dict:
+            return (output,)
+
+        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)
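Since step() reduces to x0 = x1 - u, one-step sampling is a single model call followed by a single step. A minimal sketch using the class above (the random tensor stands in for the model's mean-velocity prediction over [0, 1]; the real wiring goes through the toolkit's pipelines):

import torch

scheduler = MeanFlowScheduler()
x1 = torch.randn(1, 16, 64, 64)  # pure noise at t = 1
u = torch.randn_like(x1)         # stand-in for the predicted mean velocity
x0 = scheduler.step(u, timestep=1000.0, sample=x1).prev_sample  # x0 = x1 - u
print(torch.allclose(x0, x1 - u))  # True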