mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-23 00:03:57 +00:00
add the new flux version
Ideally there would be an 'is_flux' bool to check; using `is not shared.sd_model.is_webui_legacy_model():` instead.
This commit is contained in:
@@ -63,7 +63,12 @@ class TAESDDecoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
|
||||
if "taesd3" in str(decoder_path):
|
||||
latent_channels = 16
|
||||
elif "taef1" in str(decoder_path):
|
||||
latent_channels = 16
|
||||
else:
|
||||
latent_channels = 4
|
||||
|
||||
self.decoder = decoder(latent_channels)
|
||||
self.decoder.load_state_dict(
|
||||
@@ -79,7 +84,12 @@ class TAESDEncoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
|
||||
if "taesd3" in str(encoder_path):
|
||||
latent_channels = 16
|
||||
elif "taef1" in str(encoder_path):
|
||||
latent_channels = 16
|
||||
else:
|
||||
latent_channels = 4
|
||||
|
||||
self.encoder = encoder(latent_channels)
|
||||
self.encoder.load_state_dict(
|
||||
@@ -95,15 +105,16 @@ def download_model(model_path, model_url):
|
||||
|
||||
|
||||
def decoder_model():
|
||||
if not shared.sd_model.is_webui_legacy_model():
|
||||
return None
|
||||
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_decoder.pth"
|
||||
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
|
||||
model_name = "taef1_decoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_decoder.pth"
|
||||
else:
|
||||
elif shared.sd_model.is_sd1:
|
||||
model_name = "taesd_decoder.pth"
|
||||
else:
|
||||
return None
|
||||
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
@@ -125,6 +136,8 @@ def decoder_model():
|
||||
def encoder_model():
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_encoder.pth"
|
||||
elif not shared.sd_model.is_webui_legacy_model(): # ideally would have 'is_flux'
|
||||
model_name = "taef1_encoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_encoder.pth"
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user