Fixed some issues with the mean-flow training algorithm. Still testing — work in progress.

This commit is contained in:
Jaret Burkett
2025-06-14 12:24:00 -06:00
parent 0946a66576
commit cbf04b8d53
2 changed files with 12 additions and 7 deletions

View File

@@ -135,9 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.timestep_type == "mean_flow":
if self.train_config.loss_type == "mean_flow":
# todo handle non flux models
convert_flux_to_mean_flow(self.sd.transformer)
convert_flux_to_mean_flow(self.sd.unet)
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -811,6 +811,7 @@ class SDTrainer(BaseSDTrainProcess):
base_eps,
base_eps + jitter
)
# eps = (t_frac - r_frac) / 2
# eps = 1e-3
# primary prediction (needs grad)

View File

@@ -179,11 +179,15 @@ def add_model_gpu_splitter_to_flux(
def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
pooled_projections = self.text_embedder(pooled_projection)
@@ -193,8 +197,8 @@ def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, t
def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
# make zero timestep ending if none is passed
if timestep.shape[0] == pooled_projection.shape[0] // 2:
timestep = torch.cat([timestep, timestep], dim=0)
if timestep.shape[0] == pooled_projection.shape[0]:
timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)