mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fixed some issues with training mean flow algo. Still testing WIP
This commit is contained in:
@@ -135,9 +135,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
if self.train_config.timestep_type == "mean_flow":
|
||||
if self.train_config.loss_type == "mean_flow":
|
||||
# todo handle non flux models
|
||||
convert_flux_to_mean_flow(self.sd.transformer)
|
||||
convert_flux_to_mean_flow(self.sd.unet)
|
||||
|
||||
if self.train_config.do_prior_divergence:
|
||||
self.do_prior_prediction = True
|
||||
@@ -811,6 +811,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
base_eps,
|
||||
base_eps + jitter
|
||||
)
|
||||
# eps = (t_frac - r_frac) / 2
|
||||
|
||||
# eps = 1e-3
|
||||
# primary prediction (needs grad)
|
||||
|
||||
Reference in New Issue
Block a user