Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Work on mean flow. Minor bug fixes. Omnigen improvements
@@ -239,7 +239,10 @@ class OmniGen2Model(BaseModel):
         **kwargs
     ):
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        try:
+            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        except Exception as e:
+            pass

         # optional_kwargs = {}
         # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
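Note on the try/except introduced above: `Tensor.expand` only broadcasts a 0-d or size-1 timestep up to the batch size; if `timestep` already arrives with a batch dimension of a different size, `expand` raises a `RuntimeError`, which the new guard presumably exists to tolerate. A minimal standalone sketch of both cases (an assumed repro, not code from this repo):

import torch

latent_model_input = torch.randn(4, 16, 32, 32)

# 0-d timestep: broadcasts cleanly to the batch dimension
t_scalar = torch.tensor(500.0)
print(t_scalar.expand(latent_model_input.shape[0]).shape)  # torch.Size([4])

# mismatched pre-batched timestep: expand raises, which the guard swallows
t_batched = torch.full((2,), 500.0)
try:
    t_batched.expand(latent_model_input.shape[0])
except RuntimeError as err:
    print("expand failed:", err)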
@@ -305,14 +305,14 @@ class OmniGen2Pipeline(DiffusionPipeline):
         )

         text_input_ids = text_inputs.input_ids.to(device)
-        untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+        # untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
+        # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+        #     removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+        #     logger.warning(
+        #         "The following part of your input was truncated because Gemma can only handle sequences up to"
+        #         f" {max_sequence_length} tokens: {removed_text}"
+        #     )

         prompt_attention_mask = text_inputs.attention_mask.to(device)
         prompt_embeds = self.mllm(
@@ -459,15 +459,33 @@ class SDTrainer(BaseSDTrainProcess):
                 stepped_latents = torch.cat(stepped_chunks, dim=0)

                 stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
+                # resize to half the size of the latents
+                stepped_latents_half = torch.nn.functional.interpolate(
+                    stepped_latents,
+                    size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
+                    mode='bilinear',
+                    align_corners=False
+                )
                 pred_features = self.dfe(stepped_latents.float())
+                pred_features_half = self.dfe(stepped_latents_half.float())
                 with torch.no_grad():
                     target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
+                    batch_latents_half = torch.nn.functional.interpolate(
+                        batch.latents.to(self.device_torch, dtype=torch.float32),
+                        size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
+                        mode='bilinear',
+                        align_corners=False
+                    )
+                    target_features_half = self.dfe(batch_latents_half)
                 # scale dfe so it is weaker at higher noise levels
                 dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)

                 dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
                     self.train_config.diffusion_feature_extractor_weight * dfe_scaler
-                additional_loss += dfe_loss.mean()
+
+                dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
+                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
+                additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
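For orientation, the half-resolution branch added above is multi-scale feature matching: the same extractor runs on the latents and on a 2x bilinear downsample, and both MSE terms share the `1 - t/1000` scaler so supervision fades at high noise. A self-contained sketch of that loss shape, with a hypothetical `extract` callable standing in for `self.dfe` (whose definition is outside this diff, and which is assumed here to map (B, C, H, W) latents to 4-d feature maps):

import torch
import torch.nn.functional as F

def multi_scale_dfe_loss(extract, pred_latents, target_latents, timesteps, weight=1.0):
    def half(x):
        # 2x bilinear downsample, matching the hunk above
        return F.interpolate(
            x, size=(x.shape[2] // 2, x.shape[3] // 2),
            mode='bilinear', align_corners=False
        )

    pred_f = extract(pred_latents.float())
    pred_f_half = extract(half(pred_latents).float())
    with torch.no_grad():
        target_f = extract(target_latents.float())
        target_f_half = extract(half(target_latents).float())

    # fade the feature supervision out at high noise levels (t near 1000)
    scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1)
    loss = F.mse_loss(pred_f, target_f, reduction="none") * weight * scaler
    loss_half = F.mse_loss(pred_f_half, target_f_half, reduction="none") * weight * scaler
    return loss.mean() + loss_half.mean()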
@@ -735,136 +753,106 @@ class SDTrainer(BaseSDTrainProcess):
         unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
     ):
         # ------------------------------------------------------------------
         # “Slow” Mean-Flow loss – finite-difference version
         # (avoids JVP / double-backprop issues with Flash-Attention)
         # ------------------------------------------------------------------
-        dtype = get_torch_dtype(self.train_config.dtype)
-        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
-        base_eps = 1e-3  # this is one step when multiplied by 1000
+        dtype = get_torch_dtype(self.train_config.dtype)
+        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # e.g. 1000
+        base_eps = 1e-3
+        min_time_gap = 1e-2

         with torch.no_grad():
             num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
             batch_size = batch.latents.shape[0]
             timestep_t_list = []
             timestep_r_list = []

             for i in range(batch_size):
                 t1 = random.randint(0, num_train_timesteps - 1)
                 t2 = random.randint(0, num_train_timesteps - 1)
                 t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
                 t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
-                if (t_t - t_r).item() < base_eps * 1000:
-                    # we need to ensure the time gap is wider than the epsilon(one step)
-                    scaled_eps = base_eps * 1000
-                    if t_t.item() + scaled_eps > 1000:
-                        t_r = t_r - scaled_eps
+                if (t_t - t_r).item() < min_time_gap * 1000:
+                    scaled_time_gap = min_time_gap * 1000
+                    if t_t.item() + scaled_time_gap > 1000:
+                        t_r = t_r - scaled_time_gap
                     else:
-                        t_t = t_t + scaled_eps
+                        t_t = t_t + scaled_time_gap
                 timestep_t_list.append(t_t)
                 timestep_r_list.append(t_r)
+                eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps

             timesteps_t = torch.stack(timestep_t_list, dim=0).float()
             timesteps_r = torch.stack(timestep_r_list, dim=0).float()

-            # fractions in [0,1]
-            t_frac = timesteps_t / total_steps
-            r_frac = timesteps_r / total_steps
+            t_frac = timesteps_t / total_steps  # [0,1]
+            r_frac = timesteps_r / total_steps  # [0,1]

         # 2) construct data points
-        latents_clean = batch.latents.to(dtype)
-        noise_sample = noise.to(dtype)
+        latents_clean = batch.latents.to(dtype)
+        noise_sample = noise.to(dtype)

-        lerp_vector = noise_sample * t_frac[:, None, None, None] \
-            + latents_clean * (1.0 - t_frac[:, None, None, None])
+        lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None]

-        if hasattr(self.sd, 'get_loss_target'):
-            instantaneous_vector = self.sd.get_loss_target(
-                noise=noise_sample,
-                batch=batch,
-                timesteps=timesteps,
-            ).detach()
-        else:
-            instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)
-        eps = base_eps

         # 3) finite-difference JVP approximation (bump z **and** t)
         # eps_base, eps_jitter = 1e-3, 1e-4
         # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
+        jitter = 1e-4
+        eps = value_map(
+            torch.rand_like(t_frac),
+            0.0,
+            1.0,
+            base_eps,
+            base_eps + jitter
+        )
+        # eps = (t_frac - r_frac) / 2
+        # concatenate timesteps as input for u(z, r, t)
+        timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps

         # eps = 1e-3
-        # primary prediction (needs grad)
-        mean_vec_pred = self.predict_noise(
-            noisy_latents=lerp_vector,
-            timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
+        # model predicts u(z, r, t)
+        u_pred = self.predict_noise(
+            noisy_latents=lerp_vector.to(dtype),
+            timesteps=timesteps_cat.to(dtype),
             conditional_embeds=conditional_embeds,
             unconditional_embeds=unconditional_embeds,
             batch=batch,
             **pred_kwargs
         )

         # secondary prediction: bump both latent and timestep by ε
         with torch.no_grad():
             # lerp_perturbed = lerp_vector + eps * instantaneous_vector
-            t_frac_plus_eps = t_frac + eps  # bump time fraction
-            lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
-                + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
+            t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0)
+            lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None]
+            timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps

-            f_x_plus_eps_v = self.predict_noise(
-                noisy_latents=lerp_perturbed,
-                timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
+            u_perturbed = self.predict_noise(
+                noisy_latents=lerp_perturbed.to(dtype),
+                timesteps=timesteps_cat_perturbed.to(dtype),
                 conditional_embeds=conditional_embeds,
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
                 **pred_kwargs
             )

-        # finite-difference JVP: (f(x+εv) – f(x)) / ε
-        mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
-        # time_gap = (t_frac - r_frac)[:, None, None, None]
-        # mean_vec_scaler = time_gap / eps
-        # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
-        # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11
+        # compute du/dt via finite difference (detached)
+        du_dt = (u_perturbed - u_pred).detach() / eps
+        # du_dt = (u_perturbed - u_pred).detach()
+        du_dt = du_dt.to(dtype)


+        time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype)
+        time_gap.clamp(min=1e-4)
+        u_shifted = u_pred + time_gap * du_dt
+        # u_shifted = u_pred + du_dt / time_gap
+        # u_shifted = u_pred

-        # 4) regression target for the mean vector
-        time_gap = (t_frac - r_frac)[:, None, None, None]
-        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
-        # mean_vec_target = instantaneous_vector - mean_vec_grad
+        # a step is done like this:
+        # stepped_latent = model_input + (timestep_next - timestep) * model_output

+        # flow target velocity
+        # v_target = (noise_sample - latents_clean) / time_gap
+        # flux predicts opposite of velocity, so we need to invert it
+        v_target = (latents_clean - noise_sample) / time_gap

-        # # 5) MSE loss
-        # loss = torch.nn.functional.mse_loss(
-        #     mean_vec_pred.float(),
-        #     mean_vec_target.float()
-        # )
-        # return loss
-        # 5) MSE loss
+        # compute loss
         loss = torch.nn.functional.mse_loss(
-            mean_vec_pred.float(),
-            mean_vec_target.float(),
+            u_shifted.float(),
+            v_target.float(),
             reduction='none'
         )

+        with torch.no_grad():
+            pure_loss = loss.mean().detach()
+        # add grad to pure_loss so it can be backwards without issues
+        pure_loss.requires_grad_(True)
         # normalize the loss per batch element to 1.0
         # this method has large loss swings that can hurt the model. This method will prevent that
         # with torch.no_grad():
         #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
         #     loss = loss / loss_mean

         loss = loss.mean()

+        # backward the pure loss for logging
+        if loss.item() > 1e3:
+            pass
+        self.accelerator.backward(loss)

+        # return the real loss for logging
+        return pure_loss
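For context, the rewrite above follows the MeanFlow identity relating average and instantaneous velocity: u(z_t, r, t) = v(z_t, t) - (t - r) * du/dt, where du/dt is the total derivative along the flow. The old code regressed u toward v - (t - r) * du/dt; the new code rearranges the same identity and regresses u + (t - r) * du/dt toward v (with a model-specific sign flip and an extra 1/time_gap scale on the target that appear to be experimental). A minimal toy sketch of the finite-difference estimator under those assumptions; `model(z, r, t)` is hypothetical and stands in for `self.predict_noise`:

import torch
import torch.nn.functional as F

def mean_flow_fd_loss(model, x0, noise, r_frac, t_frac, eps=1e-3):
    # path z_t = (1 - t) * x0 + t * noise, so the instantaneous velocity is noise - x0
    tb = t_frac[:, None, None, None]
    z_t = x0 * (1.0 - tb) + noise * tb
    u_pred = model(z_t, r_frac, t_frac)  # average velocity u(z_t, r, t), keeps grad

    with torch.no_grad():
        # bump z and t together to estimate the total derivative du/dt (stop-grad)
        t_eps = (t_frac + eps).clamp(0.0, 1.0)
        teb = t_eps[:, None, None, None]
        z_eps = x0 * (1.0 - teb) + noise * teb
        du_dt = (model(z_eps, r_frac, t_eps) - u_pred) / eps

    gap = (t_frac - r_frac)[:, None, None, None].clamp(min=1e-4)
    v = noise - x0
    # identity: v(z_t, t) = u + (t - r) * du/dt, so regress the shifted prediction onto v
    return F.mse_loss(u_pred + gap * du_dt, v)

Because du_dt is computed under no_grad, gradients only flow through u_pred, matching the stop-gradient treatment in the hunk above.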
@@ -1,7 +1,7 @@
 import inspect
 import weakref
 import torch
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
 from toolkit.lora_special import LoRASpecialNetwork
 from diffusers import FluxTransformer2DModel
 from diffusers.models.embeddings import (
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
     from toolkit.custom_adapter import CustomAdapter
+    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel


 def mean_flow_time_text_embed_forward(
@@ -60,7 +61,7 @@ def mean_flow_time_text_guidance_embed_forward(
     # make zero timestep ending if none is passed
     if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
         timestep = torch.cat(
-            [timestep, torch.zeros_like(timestep)], dim=0
+            [timestep, torch.ones_like(timestep)], dim=0
         )  # timestep - 0 (final timestep) == same as start timestep
     timesteps_proj = self.time_proj(timestep)
     timesteps_emb = self.timestep_embedder(
@@ -111,6 +112,38 @@ def convert_flux_to_mean_flow(
         )
     )

+def mean_flow_omnigen2_time_text_embed_forward(
+    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
+    if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
+        timestep = torch.cat(
+            [timestep, torch.ones_like(timestep)], dim=0  # omnigen does reverse timesteps
+        )
+    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
+    time_embed = self.timestep_embedder(timestep_proj)
+
+    # mean flow stuff
+    if mean_flow_adapter.is_active:
+        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
+        orig_dtype = time_embed.dtype
+        time_embed = time_embed.to(torch.float32)
+        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
+        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
+            torch.cat([time_embed_start, time_embed_end], dim=-1)
+        )
+        time_embed = time_embed.to(orig_dtype)
+
+    caption_embed = self.caption_embedder(text_hidden_states)
+    return time_embed, caption_embed
+
+
+def convert_omnigen2_to_mean_flow(
+    transformer: 'OmniGen2Transformer2DModel',
+):
+    transformer.time_caption_embed.forward = partial(
+        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
+    )

 class MeanFlowAdapter(torch.nn.Module):
     def __init__(
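A note on the patching pattern used by `convert_omnigen2_to_mean_flow` above: `functools.partial` binds the embedder module as `self`, so the replacement forward reuses the module's own `time_proj`/`timestep_embedder`, while `mean_flow_adapter_ref` (a `weakref.ref` installed in the last hunk below) lets it reach the adapter without creating a reference cycle. A generic standalone sketch of the pattern; all names here are illustrative, not ai-toolkit API:

import weakref
from functools import partial

import torch

class TimeEmbed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(1, 8)

    def forward(self, t):
        return self.proj(t)

def patched_forward(self, t):
    # 'self' is the module bound via partial; the weakref lookup
    # mirrors the mean_flow_adapter_ref() call in the code above
    adapter = self.adapter_ref()
    out = self.proj(t)
    return out * adapter.scale if adapter is not None else out

class Adapter:
    scale = 2.0

adapter = Adapter()  # must stay alive via a strong reference elsewhere
embed = TimeEmbed()
embed.adapter_ref = weakref.ref(adapter)
embed.forward = partial(patched_forward, embed)
print(embed(torch.ones(1, 1)).shape)  # patched path: torch.Size([1, 8])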
@@ -193,6 +226,13 @@ class MeanFlowAdapter(torch.nn.Module):
                 * transformer.config.attention_head_dim
             )
             convert_flux_to_mean_flow(transformer)
+
+        elif self.model_config.arch in ["omnigen2"]:
+            transformer: 'OmniGen2Transformer2DModel' = sd.unet
+            emb_dim = (
+                1024
+            )
+            convert_omnigen2_to_mean_flow(transformer)
         else:
             raise ValueError(f"Unsupported architecture: {self.model_config.arch}")

@@ -212,6 +252,8 @@ class MeanFlowAdapter(torch.nn.Module):
         # add our adapter as a weak ref
         if self.model_config.arch in ["flux", "flex2", "flex2"]:
             sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
+        elif self.model_config.arch in ["omnigen2"]:
+            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)

     def get_params(self):
         if self.lora is not None: