https://github.com/ostris/ai-toolkit.git

Merge branch 'main' into dev
@@ -4,7 +4,7 @@ VERSION=dev
GIT_COMMIT=dev

echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION and latest"
echo "Building version: $VERSION"
# wait 2 seconds
sleep 2

@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.models.flux import convert_flux_to_mean_flow


def flush():
@@ -134,6 +135,9 @@ class SDTrainer(BaseSDTrainProcess):

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        if self.train_config.timestep_type == "mean_flow":
            # todo handle non flux models
            convert_flux_to_mean_flow(self.sd.transformer)

        if self.train_config.do_prior_divergence:
            self.do_prior_prediction = True
@@ -595,19 +599,6 @@ class SDTrainer(BaseSDTrainProcess):

        return loss + additional_loss

    def get_diff_output_preservation_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        prior_pred: Union[torch.Tensor, None] = None,
        **kwargs
    ):
        loss_target = self.train_config.loss_target

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
        return batch
@@ -641,6 +632,254 @@ class SDTrainer(BaseSDTrainProcess):
        )

        return loss

    # ------------------------------------------------------------------
    # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
    # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
    # This version avoids jvp / double-back-prop issues with Flash-Attention
    # adapted from the work of lodestonerock
    # ------------------------------------------------------------------
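    # Background (added for clarity, not part of this commit): MeanFlow trains the network
    # to predict the average velocity u(z_t, r, t) over [r, t] rather than the instantaneous
    # velocity v_t = noise - x0. The MeanFlow identity gives the regression target
    #     u_tgt = v_t - (t - r) * du/dt,
    # where du/dt is the total derivative along the flow, i.e. a JVP of u in the direction
    # (v_t, 1) with respect to (z_t, t). The two methods below approximate that JVP with a
    # forward finite difference, (u(z_t + eps * v_t, r, t + eps) - u(z_t, r, t)) / eps,
    # instead of torch.func.jvp, which avoids double-backprop through Flash-Attention.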
    def get_mean_flow_loss_wip(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        **kwargs
    ):
        batch_latents = batch.latents.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))

        time_end = timesteps.float() / 1000
        # for time_origin (timestep r), sample values uniformly at random between 0.0 and time_end
        time_origin = torch.rand_like(time_end, device=self.device_torch, dtype=time_end.dtype) * time_end
        # time_origin = torch.zeros_like(time_end, device=self.device_torch, dtype=time_end.dtype)
        # Compute noised data points
        # lerp_vector = noisy_latents
        # compute instantaneous vector
        instantaneous_vector = noise - batch_latents

        # finite difference method
        epsilon_fd = 1e-3
        jitter_std = 1e-4
        epsilon_jittered = epsilon_fd + torch.randn(1, device=batch_latents.device) * jitter_std
        epsilon_jittered = torch.clamp(epsilon_jittered, min=1e-4)

        # f(x + epsilon * v) for the primal (we backprop through here)
        # mean_vec_val_pred = self.forward(lerp_vector, class_label)
        mean_vec_val_pred = self.predict_noise(
            noisy_latents=noisy_latents,
            timesteps=torch.cat([time_end, time_origin], dim=0) * 1000,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )

        with torch.no_grad():
            perturbed_time_end = torch.clamp(time_end + epsilon_jittered, 0.0, 1.0)
            # intermediate vector to compute tangent approximation f(x + epsilon * v) ! NO GRAD HERE!
            perturbed_lerp_vector = noisy_latents + epsilon_jittered * instantaneous_vector
            # f_x_plus_eps_v = self.forward(perturbed_lerp_vector, class_label)
            f_x_plus_eps_v = self.predict_noise(
                noisy_latents=perturbed_lerp_vector,
                timesteps=torch.cat([perturbed_time_end, time_origin], dim=0) * 1000,
                conditional_embeds=conditional_embeds,
                unconditional_embeds=unconditional_embeds,
                batch=batch,
                **pred_kwargs
            )
            # JVP approximation: (f(x + epsilon * v) - f(x)) / epsilon
            mean_vec_grad_fd = (f_x_plus_eps_v - mean_vec_val_pred) / epsilon_jittered
            mean_vec_grad = mean_vec_grad_fd

        # calculate the regression target for the mean vector
        time_difference_broadcast = (time_end - time_origin)[:, None, None, None]
        mean_vec_target = instantaneous_vector - time_difference_broadcast * mean_vec_grad

        # 5) MSE loss
        loss = torch.nn.functional.mse_loss(
            mean_vec_val_pred.float(),
            mean_vec_target.float(),
            reduction='none'
        )
        with torch.no_grad():
            pure_loss = loss.mean().detach()
        # add grad to pure_loss so it can be passed through backward() without issues
        pure_loss.requires_grad_(True)
        # normalize the loss per batch element to 1.0
        # this loss has large swings that can hurt the model; per-element normalization prevents that
        with torch.no_grad():
            loss_mean = loss.mean([1, 2, 3], keepdim=True)
        loss = loss / loss_mean
        loss = loss.mean()

        # backward the normalized loss here
        self.accelerator.backward(loss)

        # return the real (unnormalized) loss for logging
        return pure_loss

    # ------------------------------------------------------------------
    # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
    # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
    # This version avoids jvp / double-back-prop issues with Flash-Attention
    # adapted from the work of lodestonerock
    # ------------------------------------------------------------------
    def get_mean_flow_loss(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        **kwargs
    ):
        # ------------------------------------------------------------------
        # “Slow” Mean-Flow loss – finite-difference version
        # (avoids JVP / double-backprop issues with Flash-Attention)
        # ------------------------------------------------------------------
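        # Note (added for clarity, not part of this commit): unlike get_mean_flow_loss_wip above,
        # this version draws (t, r) pairs from the scheduler's discrete timestep grid, rebuilds
        # the noised latent from the time fraction, and perturbs the time fraction itself for
        # the finite difference rather than nudging the latent directly.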
        dtype = get_torch_dtype(self.train_config.dtype)
        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
        base_eps = 1e-3  # this is one step when multiplied by 1000

        with torch.no_grad():
            num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
            batch_size = batch.latents.shape[0]
            timestep_t_list = []
            timestep_r_list = []
            for i in range(batch_size):
                t1 = random.randint(0, num_train_timesteps - 1)
                t2 = random.randint(0, num_train_timesteps - 1)
                t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
                t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
                if (t_t - t_r).item() < base_eps * 1000:
                    # we need to ensure the time gap is wider than the epsilon (one step)
                    scaled_eps = base_eps * 1000
                    if t_t.item() + scaled_eps > 1000:
                        t_r = t_r - scaled_eps
                    else:
                        t_t = t_t + scaled_eps
                timestep_t_list.append(t_t)
                timestep_r_list.append(t_r)
                eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps
            timesteps_t = torch.stack(timestep_t_list, dim=0).float()
            timesteps_r = torch.stack(timestep_r_list, dim=0).float()
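            # Added note (not part of this commit): the scheduler's timesteps run from high to
            # low (MeanFlowScheduler.set_train_timesteps uses torch.linspace(1000, 1, ...)), so
            # indexing with min(t1, t2) picks the larger timestep; hence t_t >= t_r as required.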

        # fractions in [0,1]
        t_frac = timesteps_t / total_steps
        r_frac = timesteps_r / total_steps

        # 2) construct data points
        latents_clean = batch.latents.to(dtype)
        noise_sample = noise.to(dtype)

        lerp_vector = noise_sample * t_frac[:, None, None, None] \
            + latents_clean * (1.0 - t_frac[:, None, None, None])

        if hasattr(self.sd, 'get_loss_target'):
            instantaneous_vector = self.sd.get_loss_target(
                noise=noise_sample,
                batch=batch,
                timesteps=timesteps,
            ).detach()
        else:
            instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)

        # 3) finite-difference JVP approximation (bump z **and** t)
        # eps_base, eps_jitter = 1e-3, 1e-4
        # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
        jitter = 1e-4
        eps = value_map(
            torch.rand_like(t_frac),
            0.0,
            1.0,
            base_eps,
            base_eps + jitter
        )

        # eps = 1e-3
        # primary prediction (needs grad)
        mean_vec_pred = self.predict_noise(
            noisy_latents=lerp_vector,
            timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )

        # secondary prediction: bump both latent and timestep by ε
        with torch.no_grad():
            # lerp_perturbed = lerp_vector + eps * instantaneous_vector
            t_frac_plus_eps = t_frac + eps  # bump time fraction
            lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
                + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
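            # Added note (not part of this commit): on the linear path
            # z_t = t * noise + (1 - t) * x0 we have dz/dt = noise - x0 = v_t, so bumping the
            # time fraction by eps also moves the latent by eps * v_t; the difference quotient
            # below therefore approximates the JVP of u along the direction (v_t, 1).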

            f_x_plus_eps_v = self.predict_noise(
                noisy_latents=lerp_perturbed,
                timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
                conditional_embeds=conditional_embeds,
                unconditional_embeds=unconditional_embeds,
                batch=batch,
                **pred_kwargs
            )

            # finite-difference JVP: (f(x+εv) – f(x)) / ε
            mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
            # time_gap = (t_frac - r_frac)[:, None, None, None]
            # mean_vec_scaler = time_gap / eps
            # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
            # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11

        # 4) regression target for the mean vector
        time_gap = (t_frac - r_frac)[:, None, None, None]
        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
        # mean_vec_target = instantaneous_vector - mean_vec_grad

        # # 5) MSE loss
        # loss = torch.nn.functional.mse_loss(
        #     mean_vec_pred.float(),
        #     mean_vec_target.float()
        # )
        # return loss

        # 5) MSE loss
        loss = torch.nn.functional.mse_loss(
            mean_vec_pred.float(),
            mean_vec_target.float(),
            reduction='none'
        )
        with torch.no_grad():
            pure_loss = loss.mean().detach()
        # add grad to pure_loss so it can be passed through backward() without issues
        pure_loss.requires_grad_(True)
        # normalize the loss per batch element to 1.0
        # this loss has large swings that can hurt the model; per-element normalization prevents that
        # with torch.no_grad():
        #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
        # loss = loss / loss_mean
        loss = loss.mean()

        # backward the loss here
        self.accelerator.backward(loss)

        # return the real loss for logging
        return pure_loss

    def get_prior_prediction(
@@ -1495,6 +1734,52 @@ class SDTrainer(BaseSDTrainProcess):
            pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample

        self.before_unet_predict()

        if unconditional_embeds is not None:
            unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
        with self.timer('condition_noisy_latents'):
            # do it for the model
            noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
            if self.adapter and isinstance(self.adapter, CustomAdapter):
                noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)

        if self.train_config.timestep_type == 'next_sample':
            with self.timer('next_sample_step'):
                with torch.no_grad():

                    stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
                    stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
                    stepped_timesteps = torch.stack(stepped_timesteps, dim=0)

                    # do a sample at the current timestep and step it, then determine new noise
                    next_sample_pred = self.predict_noise(
                        noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                        timesteps=timesteps,
                        conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                        unconditional_embeds=unconditional_embeds,
                        batch=batch,
                        **pred_kwargs
                    )
                    stepped_latents = self.sd.step_scheduler(
                        next_sample_pred,
                        noisy_latents,
                        timesteps,
                        self.sd.noise_scheduler
                    )
                    # stepped latents is our new noisy latents. Now we need to determine noise in the current sample
                    noisy_latents = stepped_latents
                    original_samples = batch.latents.to(self.device_torch, dtype=dtype)
                    # todo: calculate the next timestep properly; for now this may work as is
                    t_01 = (stepped_timesteps / 1000).to(original_samples.device)
                    if len(stepped_latents.shape) == 4:
                        t_01 = t_01.view(-1, 1, 1, 1)
                    elif len(stepped_latents.shape) == 5:
                        t_01 = t_01.view(-1, 1, 1, 1, 1)
                    else:
                        raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
                    next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
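                    # Added note (not part of this commit): this inverts the flow-match
                    # interpolation x_t = (1 - t) * x0 + t * noise at the stepped timestep,
                    # i.e. noise = (x_t - (1 - t) * x0) / t, so training can continue from the
                    # stepped sample with a consistent noise target.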
                    noise = next_sample_noise
                    timesteps = stepped_timesteps

        # do a prior pred if we have an unconditional image, we will swap out the guidance later
        if batch.unconditional_latents is not None or self.do_guided_loss:
            # do guided loss
@@ -1511,54 +1796,21 @@ class SDTrainer(BaseSDTrainProcess):
                mask_multiplier=mask_multiplier,
                prior_pred=prior_pred,
            )

        elif self.train_config.loss_type == 'mean_flow':
            loss = self.get_mean_flow_loss(
                noisy_latents=noisy_latents,
                conditional_embeds=conditional_embeds,
                match_adapter_assist=match_adapter_assist,
                network_weight_list=network_weight_list,
                timesteps=timesteps,
                pred_kwargs=pred_kwargs,
                batch=batch,
                noise=noise,
                unconditional_embeds=unconditional_embeds,
                prior_pred=prior_pred,
            )
        else:
            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
            with self.timer('condition_noisy_latents'):
                # do it for the model
                noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
                if self.adapter and isinstance(self.adapter, CustomAdapter):
                    noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)

            if self.train_config.timestep_type == 'next_sample':
                with self.timer('next_sample_step'):
                    with torch.no_grad():

                        stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
                        stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
                        stepped_timesteps = torch.stack(stepped_timesteps, dim=0)

                        # do a sample at the current timestep and step it, then determine new noise
                        next_sample_pred = self.predict_noise(
                            noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                            timesteps=timesteps,
                            conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                            unconditional_embeds=unconditional_embeds,
                            batch=batch,
                            **pred_kwargs
                        )
                        stepped_latents = self.sd.step_scheduler(
                            next_sample_pred,
                            noisy_latents,
                            timesteps,
                            self.sd.noise_scheduler
                        )
                        # stepped latents is our new noisy latents. Now we need to determine noise in the current sample
                        noisy_latents = stepped_latents
                        original_samples = batch.latents.to(self.device_torch, dtype=dtype)
                        # todo: calculate the next timestep properly; for now this may work as is
                        t_01 = (stepped_timesteps / 1000).to(original_samples.device)
                        if len(stepped_latents.shape) == 4:
                            t_01 = t_01.view(-1, 1, 1, 1)
                        elif len(stepped_latents.shape) == 5:
                            t_01 = t_01.view(-1, 1, 1, 1, 1)
                        else:
                            raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
                        next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
                        noise = next_sample_noise
                        timesteps = stepped_timesteps

            with self.timer('predict_unet'):
                noise_pred = self.predict_noise(
                    noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),

@@ -413,7 +413,7 @@ class TrainConfig:
        self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
        self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)

        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm
        self.loss_type = kwargs.get('loss_type', 'mse')  # mse, mae, wavelet, pixelspace, cfm, mean_flow

        # scale the prediction by this. Increase for more detail, decrease for less
        self.pred_scaler = kwargs.get('pred_scaler', 1.0)
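        # Usage sketch (added for clarity, not part of this commit; the YAML section name is an
        # assumption about the usual ai-toolkit config layout): a mean-flow run would set
        # something like
        #   train:
        #     loss_type: "mean_flow"      # routes SDTrainer to get_mean_flow_loss
        #     timestep_type: "mean_flow"  # makes hook_before_train_loop call convert_flux_to_mean_flow
        # and the sampler name "mean_flow" resolves to MeanFlowScheduler in get_sampler below.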
@@ -111,12 +111,12 @@ class CaptionMixin:
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')
        img_path_or_tuple = self.file_list[index]
        ext = self.dataset_config.caption_ext
        if isinstance(img_path_or_tuple, tuple):
            img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
            # check if either has a prompt file
            path_no_ext = os.path.splitext(img_path)[0]
            prompt_path = None
            ext = self.dataset_config.caption_ext
            prompt_path = path_no_ext + ext
        else:
            img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
@@ -315,7 +315,7 @@ class CaptionProcessingDTOMixin:
        # see if prompt file exists
        path_no_ext = os.path.splitext(self.path)[0]
        prompt_ext = self.dataset_config.caption_ext
        prompt_path = f"{path_no_ext}.{prompt_ext}"
        prompt_path = path_no_ext + prompt_ext
        short_caption = None

        if os.path.exists(prompt_path):

@@ -4,6 +4,7 @@ from functools import partial
from typing import Optional
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings


def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(

    transformer._pre_gpu_split_to = transformer.to
    transformer.to = partial(new_device_to, transformer)


def mean_flow_time_text_embed_forward(self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
    # make zero timestep ending if none is passed
    if timestep.shape[0] == pooled_projection.shape[0] // 2:
        timestep = torch.cat([timestep, timestep], dim=0)  # timestep - 0 (final timestep) == same as start timestep
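    # Added note (not part of this commit): the mean-flow trainer conditions on a (t, r)
    # timestep pair and passes the two halves concatenated along the batch dimension
    # (torch.cat([t_frac, r_frac], dim=0) in get_mean_flow_loss), which is why the timestep
    # tensor here can be longer than pooled_projection.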

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = timesteps_emb + pooled_projections

    return conditioning


def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
    # make zero timestep ending if none is passed
    if timestep.shape[0] == pooled_projection.shape[0] // 2:
        timestep = torch.cat([timestep, timestep], dim=0)
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)

    timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)

    time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb

    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = time_guidance_emb + pooled_projections

    return conditioning


def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )
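

# Usage sketch (added for clarity, not part of this commit): SDTrainer.hook_before_train_loop
# calls this when train_config.timestep_type == "mean_flow", e.g.
#     convert_flux_to_mean_flow(self.sd.transformer)
# after which time_text_embed.forward accepts the concatenated (t, r) timesteps.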
@@ -16,6 +16,7 @@ from diffusers import (
    LCMScheduler,
    FlowMatchEulerDiscreteScheduler,
)
from toolkit.samplers.mean_flow_scheduler import MeanFlowScheduler

from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

@@ -159,6 +160,8 @@ def get_sampler(
        scheduler_cls = LCMScheduler
    elif sampler == "custom_lcm":
        scheduler_cls = CustomLCMScheduler
    elif sampler == "mean_flow":
        scheduler_cls = MeanFlowScheduler
    elif sampler == "flowmatch":
        scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
        config_to_use = copy.deepcopy(flux_config)

toolkit/samplers/mean_flow_scheduler.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme

from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class MeanFlowScheduler(FlowMatchEulerDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        with torch.no_grad():
            # create weights for timesteps
            num_timesteps = 1000

            # Create linear timesteps from 1000 to 1
            timesteps = torch.linspace(1000, 1, num_timesteps, device="cpu")

            self.linear_timesteps = timesteps
        pass

    def get_weights_for_timesteps(
        self, timesteps: torch.Tensor, v2=False, timestep_type="linear"
    ) -> torch.Tensor:
        # Get the indices of the timesteps
        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

        weights = 1.0

        # Get the weights for the timesteps
        if timestep_type == "weighted":
            weights = torch.tensor(
                [default_weighing_scheme[i] for i in step_indices],
                device=timesteps.device,
                dtype=timesteps.dtype,
            )

        return weights

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
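        # Added note (not part of this commit): this is the standard rectified-flow / flow-match
        # interpolation x_t = (1 - t) * x0 + t * noise with t normalized to [0, 1]; t = 1 is
        # pure noise and t = 0 is the clean sample.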
        t_01 = (timesteps / 1000).to(original_samples.device)
        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
        return noisy_model_input

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        return sample

    def set_train_timesteps(self, num_timesteps, device, **kwargs):
        timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
        self.timesteps = timesteps
        return timesteps

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs: Optional[dict],
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:

        # single euler step (Eq. 5 ⇒ x₀ = x₁ − uθ)
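        # Added note (not part of this commit): the converted model predicts the mean velocity
        # u over [r, t]; sampling starts from pure noise at t = 1 with r = 0, so a single Euler
        # step over the whole interval, x0 = x1 - (1 - 0) * u, recovers the clean sample in one shot.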
        output = sample - model_output
        if not return_dict:
            return (output,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=output)

@@ -1 +1 @@
VERSION = "0.2.10"
VERSION = "0.3.0"