From a42c5a1de52bdeef8e02f2bb46ac0ab18bf12daf Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 2 Apr 2025 06:47:41 -0600 Subject: [PATCH] Adjust buckets for flex2 --- extensions_built_in/flex2/flex2.py | 3 +++ extensions_built_in/sd_trainer/SDTrainer.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/flex2/flex2.py b/extensions_built_in/flex2/flex2.py index e3234fe1..0f966c4b 100644 --- a/extensions_built_in/flex2/flex2.py +++ b/extensions_built_in/flex2/flex2.py @@ -72,6 +72,9 @@ class Flex2(BaseModel): def get_train_scheduler(): return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + def get_bucket_divisibility(self): + return 16 + def load_model(self): dtype = self.torch_dtype self.print_and_status_update("Loading Flux2 model") diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 26bc6abf..cc08cdf6 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -198,7 +198,7 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diffusion_feature_extractor_path is not None: vae = None - if self.model_config.arch != "flux": + if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch)