Work on mean flow. Minor bug fixes. Omnigen improvements

This commit is contained in:
Jaret Burkett
2025-06-26 13:46:20 -06:00
parent 84c6edca7e
commit 8d9c47316a
4 changed files with 128 additions and 95 deletions


@@ -239,7 +239,10 @@ class OmniGen2Model(BaseModel):
         **kwargs
     ):
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        try:
+            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        except Exception:
+            # timestep is already batched; keep it as-is
+            pass
         # optional_kwargs = {}
         # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
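Note on the try/except above: swallowing the exception looks odd in isolation. A plausible reading (an assumption, not stated in the diff) is that the mean flow path passes an already-batched [t, r] timestep tensor with 2B entries, for which expand(B) raises, so the original tensor should be used unchanged. A minimal repro of that failure mode:

import torch

latent_model_input = torch.randn(2, 4, 32, 32)  # batch of 2

t_single = torch.tensor([500.0])
t_single = t_single.expand(latent_model_input.shape[0])  # 1 -> 2 broadcasts fine

t_doubled = torch.tensor([500.0, 400.0, 500.0, 400.0])  # hypothetical [t, r] pairs, 2B entries
try:
    t_doubled = t_doubled.expand(latent_model_input.shape[0])  # 4 -> 2 raises RuntimeError
except RuntimeError:
    pass  # keep the already-batched timesteps, mirroring the patch above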


@@ -305,14 +305,14 @@ class OmniGen2Pipeline(DiffusionPipeline):
         )

         text_input_ids = text_inputs.input_ids.to(device)
-        untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+        # untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
+        # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+        #     removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+        #     logger.warning(
+        #         "The following part of your input was truncated because Gemma can only handle sequences up to"
+        #         f" {max_sequence_length} tokens: {removed_text}"
+        #     )

         prompt_attention_mask = text_inputs.attention_mask.to(device)
         prompt_embeds = self.mllm(


@@ -459,15 +459,33 @@ class SDTrainer(BaseSDTrainProcess):
                 stepped_latents = torch.cat(stepped_chunks, dim=0)
                 stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
+                # resize to half the size of the latents
+                stepped_latents_half = torch.nn.functional.interpolate(
+                    stepped_latents,
+                    size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
+                    mode='bilinear',
+                    align_corners=False
+                )
                 pred_features = self.dfe(stepped_latents.float())
+                pred_features_half = self.dfe(stepped_latents_half.float())
                 with torch.no_grad():
                     target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
+                    batch_latents_half = torch.nn.functional.interpolate(
+                        batch.latents.to(self.device_torch, dtype=torch.float32),
+                        size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
+                        mode='bilinear',
+                        align_corners=False
+                    )
+                    target_features_half = self.dfe(batch_latents_half)
                 # scale dfe so it is weaker at higher noise levels
                 dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
                 dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
                     self.train_config.diffusion_feature_extractor_weight * dfe_scaler
-                additional_loss += dfe_loss.mean()
+                dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
+                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
+                additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
             elif self.dfe.version == 2:
                 # version 2
                 # do diffusion feature extraction on target
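The hunk above adds a second, half-resolution term to the diffusion feature extractor (DFE) loss, so features are matched at two scales while down-weighting supervision at high noise. A self-contained sketch of the same pattern, using a plain conv as a stand-in for self.dfe and a made-up weight (both are assumptions, not the repo's definitions):

import torch
import torch.nn.functional as F

dfe = torch.nn.Conv2d(4, 16, 3, padding=1)  # stand-in feature extractor, not the repo's DFE
weight = 0.5                                 # stand-in for diffusion_feature_extractor_weight

pred_latents = torch.randn(2, 4, 64, 64, requires_grad=True)
target_latents = torch.randn(2, 4, 64, 64)
timesteps = torch.tensor([100.0, 900.0])

def half(x):
    # bilinear 2x downsample, matching the interpolate call in the patch
    return F.interpolate(x, size=(x.shape[2] // 2, x.shape[3] // 2),
                         mode='bilinear', align_corners=False)

# weaker feature supervision at higher noise levels
scaler = 1 - (timesteps / 1000.0).view(-1, 1, 1, 1)

loss_full = F.mse_loss(dfe(pred_latents), dfe(target_latents).detach(), reduction='none')
loss_half = F.mse_loss(dfe(half(pred_latents)), dfe(half(target_latents)).detach(), reduction='none')
additional_loss = (loss_full * scaler * weight).mean() + (loss_half * scaler * weight).mean()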
@@ -735,136 +753,106 @@ class SDTrainer(BaseSDTrainProcess):
         unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
     ):
         # ------------------------------------------------------------------
         # "Slow" Mean-Flow loss (finite-difference version)
         # (avoids JVP / double-backprop issues with Flash-Attention)
         # ------------------------------------------------------------------
-        dtype = get_torch_dtype(self.train_config.dtype)
-        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # 1000
-        base_eps = 1e-3  # this is one step when multiplied by 1000
+        dtype = get_torch_dtype(self.train_config.dtype)
+        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # e.g. 1000
+        base_eps = 1e-3
+        min_time_gap = 1e-2
         with torch.no_grad():
             num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
             batch_size = batch.latents.shape[0]
             timestep_t_list = []
             timestep_r_list = []
             for i in range(batch_size):
                 t1 = random.randint(0, num_train_timesteps - 1)
                 t2 = random.randint(0, num_train_timesteps - 1)
                 t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
                 t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
-                if (t_t - t_r).item() < base_eps * 1000:
-                    # we need to ensure the time gap is wider than the epsilon (one step)
-                    scaled_eps = base_eps * 1000
-                    if t_t.item() + scaled_eps > 1000:
-                        t_r = t_r - scaled_eps
+                if (t_t - t_r).item() < min_time_gap * 1000:
+                    # ensure the time gap is at least the minimum
+                    scaled_time_gap = min_time_gap * 1000
+                    if t_t.item() + scaled_time_gap > 1000:
+                        t_r = t_r - scaled_time_gap
                     else:
-                        t_t = t_t + scaled_eps
+                        t_t = t_t + scaled_time_gap
                 timestep_t_list.append(t_t)
                 timestep_r_list.append(t_r)
-                eps = min((t_t - t_r).item(), 1e-3) / num_train_timesteps

             timesteps_t = torch.stack(timestep_t_list, dim=0).float()
             timesteps_r = torch.stack(timestep_r_list, dim=0).float()

-            # fractions in [0,1]
-            t_frac = timesteps_t / total_steps
-            r_frac = timesteps_r / total_steps
+            t_frac = timesteps_t / total_steps  # [0,1]
+            r_frac = timesteps_r / total_steps  # [0,1]

             # 2) construct data points
             latents_clean = batch.latents.to(dtype)
             noise_sample = noise.to(dtype)

-            lerp_vector = noise_sample * t_frac[:, None, None, None] \
-                + latents_clean * (1.0 - t_frac[:, None, None, None])
+            lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None]

-            instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)
+            if hasattr(self.sd, 'get_loss_target'):
+                instantaneous_vector = self.sd.get_loss_target(
+                    noise=noise_sample,
+                    batch=batch,
+                    timesteps=timesteps,
+                ).detach()
+            else:
+                instantaneous_vector = noise_sample - latents_clean  # v_t (B,C,H,W)

-            eps = base_eps
             # 3) finite-difference JVP approximation (bump z **and** t)
-            # eps_base, eps_jitter = 1e-3, 1e-4
-            # eps = (eps_base + torch.randn(1, device=lerp_vector.device) * eps_jitter).clamp_(min=1e-4)
+            jitter = 1e-4
+            eps = value_map(
+                torch.rand_like(t_frac),
+                0.0,
+                1.0,
+                base_eps,
+                base_eps + jitter
+            )
+            # eps = (t_frac - r_frac) / 2

+            # concatenate timesteps as input for u(z, r, t)
+            timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps
-        # eps = 1e-3

-        # primary prediction (needs grad)
-        mean_vec_pred = self.predict_noise(
-            noisy_latents=lerp_vector,
-            timesteps=torch.cat([t_frac, r_frac], dim=0) * total_steps,
+        # model predicts u(z, r, t)
+        u_pred = self.predict_noise(
+            noisy_latents=lerp_vector.to(dtype),
+            timesteps=timesteps_cat.to(dtype),
             conditional_embeds=conditional_embeds,
             unconditional_embeds=unconditional_embeds,
             batch=batch,
             **pred_kwargs
         )

         # secondary prediction: bump both latent and timestep by ε
         with torch.no_grad():
-            # lerp_perturbed = lerp_vector + eps * instantaneous_vector
-            t_frac_plus_eps = t_frac + eps  # bump time fraction
-            lerp_perturbed = noise_sample * t_frac_plus_eps[:, None, None, None] \
-                + latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None])
+            t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0)
+            lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None]
+            timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps
-            f_x_plus_eps_v = self.predict_noise(
-                noisy_latents=lerp_perturbed,
-                timesteps=torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps,
+            u_perturbed = self.predict_noise(
+                noisy_latents=lerp_perturbed.to(dtype),
+                timesteps=timesteps_cat_perturbed.to(dtype),
                 conditional_embeds=conditional_embeds,
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
                 **pred_kwargs
             )

-        # finite-difference JVP: (f(x+εv) - f(x)) / ε
-        mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) / eps
-        # time_gap = (t_frac - r_frac)[:, None, None, None]
-        # mean_vec_scaler = time_gap / eps
-        # mean_vec_grad = (f_x_plus_eps_v - mean_vec_pred) * mean_vec_scaler
-        # mean_vec_grad = mean_vec_grad.detach()  # stop-grad as in Eq. 11
+        # compute du/dt via finite difference (detached)
+        du_dt = (u_perturbed - u_pred).detach() / eps
+        # du_dt = (u_perturbed - u_pred).detach()
+        du_dt = du_dt.to(dtype)
+        time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype)
+        time_gap = time_gap.clamp(min=1e-4)  # clamp is not in-place, so reassign
+        u_shifted = u_pred + time_gap * du_dt
+        # u_shifted = u_pred + du_dt / time_gap
+        # u_shifted = u_pred

-        # 4) regression target for the mean vector
-        time_gap = (t_frac - r_frac)[:, None, None, None]
-        mean_vec_target = instantaneous_vector - time_gap * mean_vec_grad
-        # mean_vec_target = instantaneous_vector - mean_vec_grad
+        # a step is done like this:
+        # stepped_latent = model_input + (timestep_next - timestep) * model_output
+        # flow target velocity:
+        # v_target = (noise_sample - latents_clean) / time_gap
+        # flux predicts the opposite of the velocity, so we need to invert it
+        v_target = (latents_clean - noise_sample) / time_gap

-        # # 5) MSE loss
-        # loss = torch.nn.functional.mse_loss(
-        #     mean_vec_pred.float(),
-        #     mean_vec_target.float()
-        # )
-        # return loss
-        # 5) MSE loss
+        # compute loss
         loss = torch.nn.functional.mse_loss(
-            mean_vec_pred.float(),
-            mean_vec_target.float(),
+            u_shifted.float(),
+            v_target.float(),
             reduction='none'
         )

         with torch.no_grad():
             pure_loss = loss.mean().detach()
         # give pure_loss a grad flag so backward() can be called on it without issues
         pure_loss.requires_grad_(True)

-        # normalize the loss per batch element to 1.0
-        # this method has large loss swings that can hurt the model; normalizing would prevent that
-        # with torch.no_grad():
-        #     loss_mean = loss.mean([1, 2, 3], keepdim=True)
-        #     loss = loss / loss_mean
         loss = loss.mean()
-        # backward the pure loss for logging
-        if loss.item() > 1e3:
-            pass
         self.accelerator.backward(loss)
         # return the real loss for logging
         return pure_loss
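The rewritten loss regresses u_pred + (t - r) * du/dt onto an inverted, gap-scaled velocity target, with du/dt approximated by a finite difference between two forward passes instead of a JVP. A toy, self-contained version of the same objective; the two-layer net and the u(z, r, t) helper are stand-ins for self.predict_noise, not the repo's API:

import torch

net = torch.nn.Sequential(torch.nn.Linear(4 + 2, 64), torch.nn.SiLU(), torch.nn.Linear(64, 4))

def u(z, r, t):
    # u(z_t, r, t): predicted mean velocity over the interval [r, t]
    return net(torch.cat([z, r[:, None], t[:, None]], dim=-1))

x = torch.randn(8, 4)      # clean latents
noise = torch.randn(8, 4)
t = torch.rand(8)
r = t * torch.rand(8)      # r <= t
eps = 1e-3

# rectified-flow interpolation, matching the lerp in the diff
z_t = x * (1 - t[:, None]) + noise * t[:, None]
t_eps = (t + eps).clamp(0.0, 1.0)
z_t_eps = x * (1 - t_eps[:, None]) + noise * t_eps[:, None]

u_pred = u(z_t, r, t)
with torch.no_grad():
    # finite-difference stand-in for the JVP; detached like the patch's du_dt
    du_dt = (u(z_t_eps, r, t_eps) - u_pred) / eps

time_gap = (t - r)[:, None].clamp(min=1e-4)
v_target = (x - noise) / time_gap  # inverted velocity, scaled by the gap as in the diff
loss = torch.nn.functional.mse_loss(u_pred + time_gap * du_dt, v_target)
loss.backward()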


@@ -1,7 +1,7 @@
 import inspect
 import weakref
 import torch
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
 from toolkit.lora_special import LoRASpecialNetwork
 from diffusers import FluxTransformer2DModel
 from diffusers.models.embeddings import (
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
     from toolkit.custom_adapter import CustomAdapter
+    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel


 def mean_flow_time_text_embed_forward(
@@ -60,7 +61,7 @@ def mean_flow_time_text_guidance_embed_forward(
     # make a default timestep ending if none is passed
     if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
         timestep = torch.cat(
-            [timestep, torch.zeros_like(timestep)], dim=0
+            [timestep, torch.ones_like(timestep)], dim=0
         )  # pair each start timestep with a fixed ending timestep
     timesteps_proj = self.time_proj(timestep)
     timesteps_emb = self.timestep_embedder(
@@ -111,6 +112,38 @@ def convert_flux_to_mean_flow(
         )
     )
+
+
+def mean_flow_omnigen2_time_text_embed_forward(
+    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
+    if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
+        timestep = torch.cat(
+            [timestep, torch.ones_like(timestep)], dim=0  # omnigen does reverse timesteps
+        )
+
+    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
+    time_embed = self.timestep_embedder(timestep_proj)
+
+    # mean flow stuff
+    if mean_flow_adapter.is_active:
+        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
+        orig_dtype = time_embed.dtype
+        time_embed = time_embed.to(torch.float32)
+        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
+        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
+            torch.cat([time_embed_start, time_embed_end], dim=-1)
+        )
+        time_embed = time_embed.to(orig_dtype)
+
+    caption_embed = self.caption_embedder(text_hidden_states)
+
+    return time_embed, caption_embed
+
+
+def convert_omnigen2_to_mean_flow(
+    transformer: 'OmniGen2Transformer2DModel',
+):
+    transformer.time_caption_embed.forward = partial(
+        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
+    )
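mean_flow_omnigen2_time_text_embed_forward is bound in place of the module's own forward via functools.partial, so self resolves to time_caption_embed. The key step is folding the doubled [t, r] batch back to one embedding per sample; a shape-level sketch, where the Linear is an assumed stand-in for mean_flow_timestep_embedder (its real definition lives elsewhere in the adapter):

import torch

emb_dim = 1024  # matches the omnigen2 branch below
mean_flow_timestep_embedder = torch.nn.Linear(2 * emb_dim, emb_dim)  # hypothetical stand-in

batch = 2
time_embed = torch.randn(2 * batch, emb_dim)  # embeddings for [t_0..t_B-1, r_0..r_B-1]
start, end = time_embed.chunk(2, dim=0)       # (B, emb_dim) each
fused = mean_flow_timestep_embedder(torch.cat([start, end], dim=-1))
assert fused.shape == (batch, emb_dim)        # one (t, r)-conditioned embedding per sample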
 class MeanFlowAdapter(torch.nn.Module):
     def __init__(
@@ -193,6 +226,13 @@ class MeanFlowAdapter(torch.nn.Module):
                 * transformer.config.attention_head_dim
             )
             convert_flux_to_mean_flow(transformer)
+        elif self.model_config.arch in ["omnigen2"]:
+            transformer: 'OmniGen2Transformer2DModel' = sd.unet
+            emb_dim = 1024
+            convert_omnigen2_to_mean_flow(transformer)
         else:
             raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
@@ -212,6 +252,8 @@ class MeanFlowAdapter(torch.nn.Module):
         # add our adapter as a weak ref
         if self.model_config.arch in ["flux", "flex2"]:
             sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
+        elif self.model_config.arch in ["omnigen2"]:
+            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)

     def get_params(self):
         if self.lora is not None:
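The adapter handle is stored as a weakref because the adapter and the model reference each other; a strong reference on the submodule would create a cycle that keeps both alive. A minimal illustration of the pattern (the class here is a stand-in, not the real adapter):

import weakref
import torch

class MeanFlowAdapterSketch(torch.nn.Module):  # stand-in, not the real MeanFlowAdapter
    def attach(self, module: torch.nn.Module):
        # store only a weak reference; call it like a function to dereference
        module.mean_flow_adapter_ref = weakref.ref(self)

adapter = MeanFlowAdapterSketch()
embed = torch.nn.Identity()  # stands in for sd.unet.time_caption_embed
adapter.attach(embed)
assert embed.mean_flow_adapter_ref() is adapter  # resolves while the adapter is alive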