diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 0a8bd668..a63f6e6a 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + if "taesd3" in str(decoder_path): + latent_channels = 16 + elif "taef1" in str(decoder_path): + latent_channels = 16 + else: + latent_channels = 4 self.decoder = decoder(latent_channels) self.decoder.load_state_dict( @@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + if "taesd3" in str(encoder_path): + latent_channels = 16 + elif "taef1" in str(encoder_path): + latent_channels = 16 + else: + latent_channels = 4 self.encoder = encoder(latent_channels) self.encoder.load_state_dict( @@ -95,11 +105,10 @@ def download_model(model_path, model_url): def decoder_model(): - if not shared.sd_model.is_webui_legacy_model(): - return None - if shared.sd_model.is_sd3: model_name = "taesd3_decoder.pth" + elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux' + model_name = "taef1_decoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_decoder.pth" else: @@ -125,6 +134,8 @@ def decoder_model(): def encoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_encoder.pth" + elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux' + model_name = "taef1_encoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_encoder.pth" else: