Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00)
Work on mean flow. Minor bug fixes. Omnigen improvements
@@ -1,7 +1,7 @@
 import inspect
 import weakref
 import torch
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
 from toolkit.lora_special import LoRASpecialNetwork
 from diffusers import FluxTransformer2DModel
 from diffusers.models.embeddings import (
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
     from toolkit.custom_adapter import CustomAdapter
+    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import OmniGen2Transformer2DModel


 def mean_flow_time_text_embed_forward(
@@ -60,7 +61,7 @@ def mean_flow_time_text_guidance_embed_forward(
     # make zero timestep ending if none is passed
     if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
         timestep = torch.cat(
-            [timestep, torch.zeros_like(timestep)], dim=0
+            [timestep, torch.ones_like(timestep)], dim=0
         ) # timestep - 0 (final timestep) == same as start timestep
     timesteps_proj = self.time_proj(timestep)
     timesteps_emb = self.timestep_embedder(
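Mean-flow training needs a time embedding for both ends of the integration interval, but the pipeline only passes the start timesteps. When the adapter is active, the patched forward above therefore doubles the timestep batch with a constant end time before projecting, so the first half of the projected embeddings corresponds to the start times and the second half to the end times. A minimal sketch of that doubling, with an illustrative batch size and pooled-projection width:

import torch

# hypothetical start timesteps, one per sample
timestep = torch.tensor([0.25, 0.70, 0.90])
pooled_projection = torch.randn(3, 768)  # batch size matches the timesteps

# same doubling the patched forward applies when the adapter is active
if timestep.shape[0] == pooled_projection.shape[0]:
    timestep = torch.cat([timestep, torch.ones_like(timestep)], dim=0)

print(timestep)        # tensor([0.2500, 0.7000, 0.9000, 1.0000, 1.0000, 1.0000])
print(timestep.shape)  # torch.Size([6]): first half = start times, second half = end times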
@@ -111,6 +112,38 @@ def convert_flux_to_mean_flow(
         )
     )
 
+
+def mean_flow_omnigen2_time_text_embed_forward(
+    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
+    if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
+        timestep = torch.cat(
+            [timestep, torch.ones_like(timestep)], dim=0  # omnigen does reverse timesteps
+        )
+    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
+    time_embed = self.timestep_embedder(timestep_proj)
+
+    # mean flow stuff
+    if mean_flow_adapter.is_active:
+        # todo make sure that timesteps is batched correctly, I think diffusers expects non batched timesteps
+        orig_dtype = time_embed.dtype
+        time_embed = time_embed.to(torch.float32)
+        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
+        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
+            torch.cat([time_embed_start, time_embed_end], dim=-1)
+        )
+        time_embed = time_embed.to(orig_dtype)
+
+    caption_embed = self.caption_embedder(text_hidden_states)
+    return time_embed, caption_embed
+
+
+def convert_omnigen2_to_mean_flow(
+    transformer: 'OmniGen2Transformer2DModel',
+):
+    transformer.time_caption_embed.forward = partial(
+        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
+    )
+
 
 class MeanFlowAdapter(torch.nn.Module):
     def __init__(
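convert_omnigen2_to_mean_flow replaces the forward of OmniGen2's existing time_caption_embed module by binding a module-level function to it with functools.partial, the same partial-binding approach used for the Flux embedder; inside the replacement, self is the original embedder, so its time_proj, timestep_embedder, and caption_embedder submodules stay usable. A self-contained sketch of that patching pattern (DummyEmbed and its fields are illustrative stand-ins, not the real OmniGen2 module):

from functools import partial

import torch


class DummyEmbed(torch.nn.Module):
    # stand-in for time_caption_embed; the single submodule is illustrative
    def __init__(self):
        super().__init__()
        self.time_proj = torch.nn.Linear(1, 16)

    def forward(self, timestep):
        return self.time_proj(timestep)


def patched_forward(self, timestep):
    # `self` is the embedder instance bound via partial, so its submodules
    # (here time_proj) remain reachable from the replacement forward
    return self.time_proj(timestep) * 2.0


embed = DummyEmbed()
embed.forward = partial(patched_forward, embed)  # same trick as the convert_* helpers

print(embed.forward(torch.ones(2, 1)).shape)  # torch.Size([2, 16])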
@@ -193,6 +226,13 @@ class MeanFlowAdapter(torch.nn.Module):
                 * transformer.config.attention_head_dim
             )
             convert_flux_to_mean_flow(transformer)
+
+        elif self.model_config.arch in ["omnigen2"]:
+            transformer: 'OmniGen2Transformer2DModel' = sd.unet
+            emb_dim = (
+                1024
+            )
+            convert_omnigen2_to_mean_flow(transformer)
         else:
             raise ValueError(f"Unsupported architecture: {self.model_config.arch}")
 
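Unlike the Flux branch, which computes emb_dim from the transformer config (note the * transformer.config.attention_head_dim context line), the omnigen2 branch hard-codes emb_dim to 1024. The diff does not show how the adapter turns emb_dim into mean_flow_timestep_embedder, but since the patched forward feeds it torch.cat([time_embed_start, time_embed_end], dim=-1), any shape-compatible module must map 2 * emb_dim back down to emb_dim; the sketch below is purely an assumed illustration, not the repository's actual embedder:

import torch

emb_dim = 1024  # value hard-coded for omnigen2 in the hunk above

# hypothetical embedder fusing the (start, end) time embeddings back to emb_dim
mean_flow_timestep_embedder = torch.nn.Sequential(
    torch.nn.Linear(emb_dim * 2, emb_dim),
    torch.nn.SiLU(),
    torch.nn.Linear(emb_dim, emb_dim),
)

time_embed_start = torch.randn(2, emb_dim)
time_embed_end = torch.randn(2, emb_dim)
fused = mean_flow_timestep_embedder(
    torch.cat([time_embed_start, time_embed_end], dim=-1)
)
print(fused.shape)  # torch.Size([2, 1024])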
@@ -212,6 +252,8 @@ class MeanFlowAdapter(torch.nn.Module):
         # add our adapter as a weak ref
         if self.model_config.arch in ["flux", "flex2", "flex2"]:
             sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
+        elif self.model_config.arch in ["omnigen2"]:
+            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)
 
     def get_params(self):
         if self.lora is not None:
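The adapter is attached to the patched embedding module as a weakref rather than a plain attribute, so the transformer does not hold a strong reference to the adapter or register it as a submodule; the patched forwards retrieve it by calling self.mean_flow_adapter_ref(). A minimal sketch of that pattern with placeholder classes:

import weakref

import torch


class TinyEmbed(torch.nn.Module):
    # stand-in for time_text_embed / time_caption_embed
    def forward(self, x):
        adapter = self.mean_flow_adapter_ref()  # dereference the weakref set below
        if adapter is not None and adapter.is_active:
            x = x * adapter.scale
        return x


class TinyAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.is_active = True
        self.scale = 2.0  # illustrative; the real adapter holds an embedder and a LoRA


embed = TinyEmbed()
adapter = TinyAdapter()
embed.mean_flow_adapter_ref = weakref.ref(adapter)  # weak link, no submodule registration

print(embed(torch.ones(2)))  # tensor([2., 2.])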