Fixed some issues with training mean flow algo. Still testing WIP

This commit is contained in:
Jaret Burkett
2025-06-14 12:24:00 -06:00
parent 3f0ae99d48
commit c0314ba325
2 changed files with 12 additions and 7 deletions

View File

@@ -135,9 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.train_config.timestep_type == "mean_flow":
if self.train_config.loss_type == "mean_flow":
# todo handle non flux models
convert_flux_to_mean_flow(self.sd.transformer)
convert_flux_to_mean_flow(self.sd.unet)
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
@@ -811,6 +811,7 @@ class SDTrainer(BaseSDTrainProcess):
base_eps,
base_eps + jitter
)
# eps = (t_frac - r_frac) / 2
# eps = 1e-3
# primary prediction (needs grad)