Mirror of https://github.com/ostris/ai-toolkit.git
Fixed some issues with the mean flow training algorithm. Still testing, WIP.
@@ -135,9 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
     def hook_before_train_loop(self):
         super().hook_before_train_loop()
-        if self.train_config.timestep_type == "mean_flow":
+        if self.train_config.loss_type == "mean_flow":
             # todo handle non flux models
-            convert_flux_to_mean_flow(self.sd.transformer)
+            convert_flux_to_mean_flow(self.sd.unet)
 
         if self.train_config.do_prior_divergence:
             self.do_prior_prediction = True
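The conversion is now gated on `loss_type` and applied to the transformer stored on `self.sd.unet`. For orientation only, here is a minimal hypothetical sketch (not this repo's actual implementation) of how such a conversion could monkey-patch the diffusers time/text embedder with the mean-flow forwards that appear later in this diff; the `time_text_embed` attribute name and the helper name are assumptions about diffusers' Flux transformer layout, and the two patched forwards are assumed to be importable from this module.

# Hypothetical sketch: bind the patched mean-flow forwards onto the Flux
# transformer's time/text embedder. Not the repo's convert_flux_to_mean_flow.
import types

from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)

def convert_flux_to_mean_flow_sketch(transformer):
    embedder = transformer.time_text_embed  # assumed diffusers attribute name
    if isinstance(embedder, CombinedTimestepGuidanceTextProjEmbeddings):
        embedder.forward = types.MethodType(
            mean_flow_time_text_guidance_embed_forward, embedder
        )
    elif isinstance(embedder, CombinedTimestepTextProjEmbeddings):
        embedder.forward = types.MethodType(
            mean_flow_time_text_embed_forward, embedder
        )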
@@ -811,6 +811,7 @@ class SDTrainer(BaseSDTrainProcess):
             base_eps,
             base_eps + jitter
         )
         # eps = (t_frac - r_frac) / 2
 
         # eps = 1e-3
         # primary prediction (needs grad)
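This hunk only touches the epsilon used for the mean-flow step: it is kept between `base_eps` and `base_eps + jitter`, with `(t_frac - r_frac) / 2` and a fixed `1e-3` left behind as commented-out alternatives before the primary (grad-carrying) prediction. A rough sketch of such an epsilon schedule follows; the function name, the uniform draw, and the default values are assumptions, not the repo's code.

import torch

def sample_mean_flow_eps(t_frac, r_frac, base_eps=1e-3, jitter=1e-4):
    # Hedged sketch: draw a small per-sample step size bounded by
    # [base_eps, base_eps + jitter]; t_frac/r_frac are assumed to be tensors
    # of start/end time fractions, and the default values are illustrative.
    eps = torch.empty_like(t_frac).uniform_(base_eps, base_eps + jitter)
    # Alternative left commented out in the diff: half the t - r gap.
    # eps = (t_frac - r_frac) / 2
    return eps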
@@ -179,11 +179,15 @@ def add_model_gpu_splitter_to_flux(
 def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, timestep, pooled_projection):
     # make zero timestep ending if none is passed
-    if timestep.shape[0] == pooled_projection.shape[0] // 2:
-        timestep = torch.cat([timestep, timestep], dim=0)  # timestep - 0 (final timestep) == same as start timestep
+    if timestep.shape[0] == pooled_projection.shape[0]:
+        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep
 
     timesteps_proj = self.time_proj(timestep)
-    timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+    timesteps_emb_combo = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+    timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
+
+    timesteps_emb = timesteps_emb_start + timesteps_emb_end
 
     pooled_projections = self.text_embedder(pooled_projection)
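Net effect of this hunk: the patched embedder treats `timestep` as the start and end timesteps (t and r) stacked along the batch dimension, embeds both in one pass, then splits the result and sums the two halves into a single conditioning vector; when only t is supplied, r is padded with zeros (the final timestep). A self-contained toy of that bookkeeping with plain tensors, where the fake embedding and the dimensions are illustrative assumptions:

import torch

batch, dim = 2, 8
pooled_projection = torch.randn(batch, dim)
t = torch.rand(batch)  # start timesteps sampled by the trainer

# Only the start timestep is passed, so pad the end timestep r with zeros.
timestep = t
if timestep.shape[0] == pooled_projection.shape[0]:
    timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)

# Stand-in for time_proj + timestep_embedder: embed all 2*batch timesteps at once.
timesteps_emb_combo = timestep[:, None].expand(-1, dim)  # (2*batch, dim), fake embedding

# Split back into the t-half and the r-half, then combine into one vector per sample.
timesteps_emb_start, timesteps_emb_end = timesteps_emb_combo.chunk(2, dim=0)
timesteps_emb = timesteps_emb_start + timesteps_emb_end
assert timesteps_emb.shape == (batch, dim)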
@@ -193,8 +197,8 @@ def mean_flow_time_text_embed_forward(self:CombinedTimestepTextProjEmbeddings, t
 def mean_flow_time_text_guidance_embed_forward(self: CombinedTimestepGuidanceTextProjEmbeddings, timestep, guidance, pooled_projection):
     # make zero timestep ending if none is passed
-    if timestep.shape[0] == pooled_projection.shape[0] // 2:
-        timestep = torch.cat([timestep, timestep], dim=0)
+    if timestep.shape[0] == pooled_projection.shape[0]:
+        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)  # timestep - 0 (final timestep) == same as start timestep
     timesteps_proj = self.time_proj(timestep)
     timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)