Adjust buckets for flex2

This commit is contained in:
Jaret Burkett
2025-04-02 06:47:41 -06:00
parent 3d131fb27a
commit a42c5a1de5
2 changed files with 4 additions and 1 deletions

View File

@@ -72,6 +72,9 @@ class Flex2(BaseModel):
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 16
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading Flux2 model")

View File

@@ -198,7 +198,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diffusion_feature_extractor_path is not None:
vae = None
if self.model_config.arch != "flux":
if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
vae = self.sd.vae
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
self.dfe.to(self.device_torch)