diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 6e1daf39..f3813133 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -135,9 +135,9 @@ class SDTrainer(BaseSDTrainProcess): def hook_before_train_loop(self): super().hook_before_train_loop() - if self.train_config.timestep_type == "mean_flow": + if self.train_config.loss_type == "mean_flow": # todo handle non flux models - convert_flux_to_mean_flow(self.sd.transformer) + convert_flux_to_mean_flow(self.sd.unet) if self.train_config.do_prior_divergence: self.do_prior_prediction = True @@ -811,6 +811,7 @@ class SDTrainer(BaseSDTrainProcess): base_eps, base_eps + jitter ) + # eps = (t_frac - r_frac) / 2 # eps = 1e-3 # primary prediction (needs grad) diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py index 42194179..5d3064f4 100644 --- a/toolkit/models/flux.py +++ b/toolkit/models/flux.py @@ -179,11 +179,15 @@ def add_model_gpu_splitter_to_flux( def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection): # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0] // 2: - timestep = torch.cat([timestep, timestep], dim=0) # timestep - 0 (final timestep) == same as start timestep + if timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0) + + timesteps_emb = timesteps_emb_start + timesteps_emb_end pooled_projections = self.text_embedder(pooled_projection) @@ -193,8 +197,8 @@ def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, t def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection): # make zero timestep ending if none is passed - if timestep.shape[0] == pooled_projection.shape[0] // 2: - timestep = torch.cat([timestep, timestep], dim=0) + if timestep.shape[0] == pooled_projection.shape[0]: + timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0) # timestep - 0 (final timestep) == same as start timestep timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)