diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 89fb0e05..26bc6abf 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -197,7 +197,10 @@ class SDTrainer(BaseSDTrainProcess): flush() if self.train_config.diffusion_feature_extractor_path is not None: - self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path) + vae = None + if self.model_config.arch != "flux": + vae = self.sd.vae + self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch) self.dfe.eval() diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 29becb69..9244a7c1 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -154,11 +154,12 @@ class DiffusionFeatureExtractor(nn.Module): class DiffusionFeatureExtractor3(nn.Module): - def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): super().__init__() self.version = 3 - vae = AutoencoderTiny.from_pretrained( - "madebyollin/taef1", torch_dtype=torch.bfloat16) + if vae is None: + vae = AutoencoderTiny.from_pretrained( + "madebyollin/taef1", torch_dtype=torch.bfloat16) self.vae = vae image_encoder_path = "google/siglip-so400m-patch14-384" try: @@ -342,9 +343,9 @@ class DiffusionFeatureExtractor3(nn.Module): return total_loss -def load_dfe(model_path) -> DiffusionFeatureExtractor: +def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: if model_path == "v3": - dfe = DiffusionFeatureExtractor3() + dfe = DiffusionFeatureExtractor3(vae=vae) dfe.eval() return dfe if not os.path.exists(model_path):