mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fixed some issues with training mean flow algo. Still testing WIP
This commit is contained in:
@@ -179,11 +179,15 @@ def add_model_gpu_splitter_to_flux(
|
||||
|
||||
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
|
||||
# make zero timestep ending if none is passed
|
||||
if timestep.shape[0] == pooled_projection.shape[0] // 2:
|
||||
timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep
|
||||
if timestep.shape[0] == pooled_projection.shape[0]:
|
||||
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
|
||||
|
||||
timesteps_emb = timesteps_emb_start + timesteps_emb_end
|
||||
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
|
||||
@@ -193,8 +197,8 @@ def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, t
|
||||
|
||||
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
|
||||
# make zero timestep ending if none is passed
|
||||
if timestep.shape[0] == pooled_projection.shape[0] // 2:
|
||||
timestep = torch.cat([timestep, timestep], dim=0)
|
||||
if timestep.shape[0] == pooled_projection.shape[0]:
|
||||
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user