More work on mean flow loss. Moved it to an adapter. Still not functioning properly though.

This commit is contained in:
Jaret Burkett
2025-06-16 07:17:35 -06:00
parent c0314ba325
commit 1c2b7298dd
6 changed files with 323 additions and 165 deletions

View File

@@ -176,60 +176,3 @@ def add_model_gpu_splitter_to_flux(
transformer._pre_gpu_split_to = transformer.to
transformer.to = partial(new_device_to, transformer)
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
return conditioning
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
time_guidance_emb = timesteps_emb_start + timesteps_emb_end + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
def convert_flux_to_mean_flow(
transformer: FluxTransformer2DModel,
):
if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_embed_forward, transformer.time_text_embed
)
elif isinstance(transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings):
transformer.time_text_embed.forward = partial(
mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
)
else:
raise ValueError(
"Unsupported time_text_embed type: {}".format(
type(transformer.time_text_embed)
)
)