Fixed some issues with training mean flow algo. Still testing WIP

This commit is contained in:
Jaret Burkett
2025-06-14 12:24:00 -06:00
parent 3f0ae99d48
commit c0314ba325
2 changed files with 12 additions and 7 deletions

View File

@@ -179,11 +179,15 @@ def add_model_gpu_splitter_to_flux(
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
pooled_projections = self.text_embedder(pooled_projection)
@@ -193,8 +197,8 @@ def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, t
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)