Allow DFE to not have a VAE

This commit is contained in:
Jaret Burkett
2025-03-30 09:23:01 -06:00
parent 860d892214
commit c083a0e5ea
2 changed files with 10 additions and 6 deletions

View File

@@ -197,7 +197,10 @@ class SDTrainer(BaseSDTrainProcess):
flush()
if self.train_config.diffusion_feature_extractor_path is not None:
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
vae = None
if self.model_config.arch != "flux":
vae = self.sd.vae
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
self.dfe.to(self.device_torch)
self.dfe.eval()