From 7ffd124a9e23fc5bedb6377bba5b37925dea79a9 Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:41:47 +0100 Subject: [PATCH 1/2] add the new flux version Ideally there would be an 'is_flux' bool to check; using `is not shared.sd_model.is_webui_legacy_model():` instead. --- modules/sd_vae_taesd.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 0a8bd668..a835548f 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(decoder_path) else 4 + if "taesd3" in str(decoder_path): + latent_channels = 16 + elif "taef1" in str(decoder_path): + latent_channels = 16 + else: + latent_channels = 4 self.decoder = decoder(latent_channels) self.decoder.load_state_dict( @@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module): super().__init__() if latent_channels is None: - latent_channels = 16 if "taesd3" in str(encoder_path) else 4 + if "taesd3" in str(encoder_path): + latent_channels = 16 + elif "taef1" in str(encoder_path): + latent_channels = 16 + else: + latent_channels = 4 self.encoder = encoder(latent_channels) self.encoder.load_state_dict( @@ -95,15 +105,16 @@ def download_model(model_path, model_url): def decoder_model(): - if not shared.sd_model.is_webui_legacy_model(): - return None - if shared.sd_model.is_sd3: model_name = "taesd3_decoder.pth" + elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux' + model_name = "taef1_decoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_decoder.pth" - else: + elif shared.sd_model.is_sd1: model_name = "taesd_decoder.pth" + else: + return None loaded_model = sd_vae_taesd_models.get(model_name) @@ -125,6 +136,8 @@ def decoder_model(): def encoder_model(): if shared.sd_model.is_sd3: model_name = "taesd3_encoder.pth" + elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux' + model_name = "taef1_encoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_encoder.pth" else: From f36fa7dc1a2d0369615992d8fef2699cd53cb240 Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Sat, 31 Aug 2024 13:05:02 +0100 Subject: [PATCH 2/2] fix for sd2 --- modules/sd_vae_taesd.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index a835548f..a63f6e6a 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -111,10 +111,8 @@ def decoder_model(): model_name = "taef1_decoder.pth" elif shared.sd_model.is_sdxl: model_name = "taesdxl_decoder.pth" - elif shared.sd_model.is_sd1: - model_name = "taesd_decoder.pth" else: - return None + model_name = "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name)