add the new flux version

Ideally there would be an 'is_flux' bool to check; using `is not shared.sd_model.is_webui_legacy_model():` instead.
This commit is contained in:
DenOfEquity
2024-08-29 16:41:47 +01:00
committed by GitHub
parent 948e91458a
commit 7ffd124a9e

View File

@@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
if "taesd3" in str(decoder_path):
latent_channels = 16
elif "taef1" in str(decoder_path):
latent_channels = 16
else:
latent_channels = 4
self.decoder = decoder(latent_channels)
self.decoder.load_state_dict(
@@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
if "taesd3" in str(encoder_path):
latent_channels = 16
elif "taef1" in str(encoder_path):
latent_channels = 16
else:
latent_channels = 4
self.encoder = encoder(latent_channels)
self.encoder.load_state_dict(
@@ -95,15 +105,16 @@ def download_model(model_path, model_url):
def decoder_model():
if not shared.sd_model.is_webui_legacy_model():
return None
if shared.sd_model.is_sd3:
model_name = "taesd3_decoder.pth"
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
model_name = "taef1_decoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_decoder.pth"
else:
elif shared.sd_model.is_sd1:
model_name = "taesd_decoder.pth"
else:
return None
loaded_model = sd_vae_taesd_models.get(model_name)
@@ -125,6 +136,8 @@ def decoder_model():
def encoder_model():
if shared.sd_model.is_sd3:
model_name = "taesd3_encoder.pth"
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
model_name = "taef1_encoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_encoder.pth"
else: