Merge pull request #1582 from DenOfEquity/tae-flux

add the new Tiny AutoEncoder for flux by madebyollin
This commit is contained in:
DenOfEquity
2024-08-31 13:15:10 +01:00
committed by GitHub

View File

@@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
if "taesd3" in str(decoder_path):
latent_channels = 16
elif "taef1" in str(decoder_path):
latent_channels = 16
else:
latent_channels = 4
self.decoder = decoder(latent_channels)
self.decoder.load_state_dict(
@@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module):
super().__init__()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
if "taesd3" in str(encoder_path):
latent_channels = 16
elif "taef1" in str(encoder_path):
latent_channels = 16
else:
latent_channels = 4
self.encoder = encoder(latent_channels)
self.encoder.load_state_dict(
@@ -95,11 +105,10 @@ def download_model(model_path, model_url):
def decoder_model():
if not shared.sd_model.is_webui_legacy_model():
return None
if shared.sd_model.is_sd3:
model_name = "taesd3_decoder.pth"
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
model_name = "taef1_decoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_decoder.pth"
else:
@@ -125,6 +134,8 @@ def decoder_model():
def encoder_model():
if shared.sd_model.is_sd3:
model_name = "taesd3_encoder.pth"
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
model_name = "taef1_encoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_encoder.pth"
else: