From 69b1827ed5715fdfc46b30d3f3681a6f22c21eea Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 7 Aug 2024 18:00:58 -0700 Subject: [PATCH] revise preview logics --- backend/diffusion_engine/base.py | 3 +++ modules/sd_samplers_common.py | 14 +++++++++--- modules/sd_vae_approx.py | 37 ++++---------------------------- modules/sd_vae_taesd.py | 3 +++ 4 files changed, 21 insertions(+), 36 deletions(-) diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py index 5ad5ac2f..1a697232 100644 --- a/backend/diffusion_engine/base.py +++ b/backend/diffusion_engine/base.py @@ -39,6 +39,9 @@ class ForgeDiffusionEngine: self.is_sdxl = False self.is_sd3 = False + def is_webui_legacy_model(self): + return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3 + def set_clip_skip(self, clip_skip): pass diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index fee65f83..bdfccee1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -47,10 +47,18 @@ def samples_to_images_tensor(sample, approximation=None, model=None): if approximation == 2: x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() + m = sd_vae_approx.model() + if m is None: + x_sample = sd_vae_approx.cheap_approximation(sample) + else: + x_sample = m(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: - x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach() - x_sample = x_sample * 2 - 1 + m = sd_vae_taesd.decoder_model() + if m is None: + x_sample = sd_vae_approx.cheap_approximation(sample) + else: + x_sample = m(sample.to(devices.device, devices.dtype)).detach() + x_sample = x_sample * 2 - 1 else: if model is None: model = shared.sd_model diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 7f7ff068..91a962b7 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -40,6 +40,9 @@ def download_model(model_path, model_url): def model(): + if not shared.sd_model.is_webui_legacy_model(): + return None + if shared.sd_model.is_sd3: model_name = "vaeapprox-sd3.pt" elif shared.sd_model.is_sdxl: @@ -68,36 +71,4 @@ def model(): def cheap_approximation(sample): - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - - if shared.sd_model.is_sd3: - coeffs = [ - [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650], - [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889], - [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284], - [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047], - [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039], - [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481], - [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867], - [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259], - ] - elif shared.sd_model.is_sdxl: - coeffs = [ - [ 0.3448, 0.4168, 0.4395], - [-0.1953, -0.0290, 0.0250], - [ 0.1074, 0.0886, -0.0163], - [-0.3730, -0.2499, -0.2088], - ] - else: - coeffs = [ - [ 0.298, 0.207, 0.208], - [ 0.187, 0.286, 0.173], - [-0.158, 0.189, 0.264], - [-0.184, -0.271, -0.473], - ] - - coefs = torch.tensor(coeffs).to(sample.device) - - x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs) - - return x_sample + return torch.einsum("...lxy,lr -> ...rxy", sample, torch.tensor(shared.sd_model.model_config.latent_format.latent_rgb_factors).to(sample.device)) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index d06253d2..0a8bd668 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -95,6 +95,9 @@ def download_model(model_path, model_url): def decoder_model(): + if not shared.sd_model.is_webui_legacy_model(): + return None + if shared.sd_model.is_sd3: model_name = "taesd3_decoder.pth" elif shared.sd_model.is_sdxl: