diff --git a/extensions_built_in/flex2/flex2.py b/extensions_built_in/flex2/flex2.py index e3234fe1..0f966c4b 100644 --- a/extensions_built_in/flex2/flex2.py +++ b/extensions_built_in/flex2/flex2.py @@ -72,6 +72,9 @@ class Flex2(BaseModel): def get_train_scheduler(): return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + def get_bucket_divisibility(self): + return 16 + def load_model(self): dtype = self.torch_dtype self.print_and_status_update("Loading Flux2 model") diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 26bc6abf..cc08cdf6 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -198,7 +198,7 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diffusion_feature_extractor_path is not None: vae = None - if self.model_config.arch != "flux": + if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch)