Work on mean flow. Minor bug fixes. Omnigen improvements

This commit is contained in:
Jaret Burkett
2025-06-26 13:46:20 -06:00
parent 84c6edca7e
commit 8d9c47316a
4 changed files with 128 additions and 95 deletions

View File

@@ -1,7 +1,7 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
from toolkit.custom_adapter import CustomAdapter
from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel
def mean_flow_time_text_embed_forward(
@@ -60,7 +61,7 @@ def mean_flow_time_text_guidance_embed_forward(
# make zero timestep ending if none is passed
if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat(
[timestep, torch.zeros_like(timestep)], dim=0
[timestep, torch.ones_like(timestep)], dim=0
) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
@@ -111,6 +112,38 @@ def convert_flux_to_mean_flow(
)
)
def mean_flow_omnigen2_time_text_embed_forward(
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
timestep = torch.cat(
[timestep, torch.ones_like(timestep)], dim=0 # omnigen does reverse timesteps
)
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
# mean flow stuff
if mean_flow_adapter.is_active:
# todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
orig_dtype = time_embed.dtype
time_embed = time_embed.to(torch.float32)
time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
torch.cat([time_embed_start, time_embed_end], dim=-1)
)
time_embed = time_embed.to(orig_dtype)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
def convert_omnigen2_to_mean_flow(
transformer: 'OmniGen2Transformer2DModel',
):
transformer.time_caption_embed.forward = partial(
mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
)
class MeanFlowAdapter(torch.nn.Module):
def __init__(
@@ -193,6 +226,13 @@ class MeanFlowAdapter(torch.nn.Module):
* transformer.config.attention_head_dim
)
convert_flux_to_mean_flow(transformer)
elif self.model_config.arch in ["omnigen2"]:
transformer: 'OmniGen2Transformer2DModel' = sd.unet
emb_dim = (
1024
)
convert_omnigen2_to_mean_flow(transformer)
else:
raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
@@ -212,6 +252,8 @@ class MeanFlowAdapter(torch.nn.Module):
# add our adapter as a weak ref
if self.model_config.arch in ["flux", "flex2", "flex2"]:
sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
elif self.model_config.arch in ["omnigen2"]:
sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)
def get_params(self):
if self.lora is not None: