mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Adjust buckets for flex2
This commit is contained in:
@@ -72,6 +72,9 @@ class Flex2(BaseModel):
|
|||||||
def get_train_scheduler():
|
def get_train_scheduler():
|
||||||
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
|
||||||
|
|
||||||
|
def get_bucket_divisibility(self):
|
||||||
|
return 16
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
dtype = self.torch_dtype
|
dtype = self.torch_dtype
|
||||||
self.print_and_status_update("Loading Flux2 model")
|
self.print_and_status_update("Loading Flux2 model")
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||||
vae = None
|
vae = None
|
||||||
if self.model_config.arch != "flux":
|
if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
||||||
vae = self.sd.vae
|
vae = self.sd.vae
|
||||||
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
||||||
self.dfe.to(self.device_torch)
|
self.dfe.to(self.device_torch)
|
||||||
|
|||||||
Reference in New Issue
Block a user