From dc62b0d2d705215eeda94a450c2617ffb9eeaa7d Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Tue, 30 Jul 2024 08:43:14 -0600
Subject: [PATCH] begin to use new vae impl

---
 backend/vae/loader.py         | 16 +++++++++++-
 ldm_patched/modules/sd.py     | 49 ++++++++++-------------------------
 modules_forge/forge_loader.py |  6 ++---
 3 files changed, 31 insertions(+), 40 deletions(-)

diff --git a/backend/vae/loader.py b/backend/vae/loader.py
index 6d858154..e904ea73 100644
--- a/backend/vae/loader.py
+++ b/backend/vae/loader.py
@@ -6,11 +6,25 @@ from backend.attention import AttentionProcessorForge
 from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
 
 
+class BaseVAE(AutoencoderKL):
+    def encode(self, x, regulation=None, mode=False):
+        latent_dist = super().encode(x).latent_dist
+        if mode:
+            return latent_dist.mode()
+        elif regulation is not None:
+            return regulation(latent_dist)
+        else:
+            return latent_dist.sample()
+
+    def decode(self, x):
+        return super().decode(x).sample
+
+
 def load_vae_from_state_dict(state_dict):
     config = guess_vae_config(state_dict)
 
     with using_forge_operations():
-        model = AutoencoderKL(**config)
+        model = BaseVAE(**config)
 
     vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
     vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index 934c0092..19550a07 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -163,48 +163,20 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None, no_init=False):
+    def __init__(self, model=None, mapping=None, device=None, dtype=None, no_init=False):
+        if mapping is None:
+            mapping = {}
+
         if no_init:
             return
 
-        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-            sd = diffusers_convert.convert_vae_state_dict(sd)
-
-        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
+        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
         self.latent_channels = 4
 
-        if config is None:
-            if "decoder.mid.block_1.mix_factor" in sd:
-                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-                decoder_config = encoder_config.copy()
-                decoder_config["video_kernel_size"] = [3, 1, 1]
-                decoder_config["alpha"] = 0.0
-                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                            encoder_config={'target': "ldm_patched.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
-                                                            decoder_config={'target': "ldm_patched.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
-            elif "taesd_decoder.1.weight" in sd:
-                self.first_stage_model = ldm_patched.taesd.taesd.TAESD()
-            else:
-                #default SD1.x/SD2.x VAE parameters
-                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-
-                if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
-                    ddconfig['ch_mult'] = [1, 2, 4]
-                    self.downscale_ratio = 4
-
-                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
-        else:
-            self.first_stage_model = AutoencoderKL(**(config['params']))
-        self.first_stage_model = self.first_stage_model.eval()
-
-        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
-        if len(m) > 0:
-            print("Missing VAE keys", m)
-
-        if len(u) > 0:
-            print("Leftover VAE keys", u)
+        self.first_stage_model = model.eval()
+        self.state_dict_mapping = mapping
 
         if device is None:
             device = model_management.vae_device()
@@ -216,7 +188,11 @@ class VAE:
         self.first_stage_model.to(self.vae_dtype)
         self.output_device = model_management.intermediate_device()
 
-        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+            self.first_stage_model,
+            load_device=self.device,
+            offload_device=offload_device
+        )
 
     def clone(self):
         n = VAE(no_init=True)
@@ -226,6 +202,7 @@ class VAE:
         n.downscale_ratio = self.downscale_ratio
         n.latent_channels = self.latent_channels
         n.first_stage_model = self.first_stage_model
+        n.state_dict_mapping = self.state_dict_mapping
        n.device = self.device
         n.vae_dtype = self.vae_dtype
         n.output_device = self.output_device
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index ba206db2..bf342b4e 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -18,6 +18,7 @@ from ldm.util import instantiate_from_config
 from modules_forge import forge_clip
 from modules_forge.unet_patcher import UnetPatcher
 from ldm_patched.modules.model_base import model_sampling, ModelType
+from backend.vae.loader import load_vae_from_state_dict
 import open_clip
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -105,9 +106,8 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
-        vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
-        vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd)
+        vae, mapping = load_vae_from_state_dict(sd)
+        vae = VAE(model=vae, mapping=mapping)
 
     if output_clip:
         w = WeightsLoader()
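
For reviewers, a minimal usage sketch of the new loading path, inferred only from this patch; the checkpoint path, the torch.load call, and the direct calls on first_stage_model are illustrative assumptions (in Forge the VAE is normally driven through its patcher):

    import torch
    from backend.vae.loader import load_vae_from_state_dict
    from ldm_patched.modules.sd import VAE

    # Placeholder path: any full SD checkpoint state dict would work here.
    state_dict = torch.load("model.ckpt", map_location="cpu")

    # load_vae_from_state_dict builds the diffusers-based BaseVAE and, per the
    # forge_loader call site, returns it together with a state-dict key mapping.
    model, mapping = load_vae_from_state_dict(state_dict)
    vae = VAE(model=model, mapping=mapping)

    # BaseVAE.encode samples from the latent distribution by default; mode=True
    # returns the distribution mode, and a `regulation` callable receives the
    # distribution itself. BaseVAE.decode returns the decoded image tensor.
    pixels = torch.randn(1, 3, 512, 512)
    latent = vae.first_stage_model.encode(pixels, mode=True)
    recon = vae.first_stage_model.decode(latent)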