mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
WIP on mean flow loss. Still a WIP.
This commit is contained in:
@@ -4,6 +4,7 @@ from functools import partial
|
||||
from typing import Optional
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, CombinedTimestepGuidanceTextProjEmbeddings
|
||||
|
||||
|
||||
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
|
||||
@@ -174,3 +175,57 @@ def add_model_gpu_splitter_to_flux(
|
||||
|
||||
transformer._pre_gpu_split_to = transformer.to
|
||||
transformer.to = partial(new_device_to, transformer)
|
||||
|
||||
|
||||
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
|
||||
# make zero timestep ending if none is passed
|
||||
if timestep.shape[0] == pooled_projection.shape[0] // 2:
|
||||
timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
return conditioning
|
||||
|
||||
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
|
||||
# make zero timestep ending if none is passed
|
||||
if timestep.shape[0] == pooled_projection.shape[0] // 2:
|
||||
timestep = torch.cat([timestep, timestep], dim=0)
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
|
||||
|
||||
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
|
||||
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
conditioning = time_guidance_emb + pooled_projections
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
def convert_flux_to_mean_flow(
|
||||
transformer: FluxTransformer2DModel,
|
||||
):
|
||||
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
|
||||
transformer.time_text_embed.forward = partial(
|
||||
mean_flow_time_text_embed_forward, transformer.time_text_embed
|
||||
)
|
||||
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
|
||||
transformer.time_text_embed.forward = partial(
|
||||
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported time_text_embed type: {}".format(
|
||||
type(transformer.time_text_embed)
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user