mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
begin to use new vae impl
This commit is contained in:
@@ -6,11 +6,25 @@ from backend.attention import AttentionProcessorForge
|
|||||||
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVAE(AutoencoderKL):
    """AutoencoderKL subclass that unwraps the parent's output objects.

    ``encode`` reads ``.latent_dist`` from the parent's encode result and
    reduces it to a single value; ``decode`` returns the ``.sample``
    attribute of the parent's decode result.
    """

    def encode(self, x, regulation=None, mode=False):
        """Encode ``x`` and reduce the resulting latent distribution.

        Args:
            x: Input forwarded unchanged to the parent ``encode``.
            regulation: Optional callable applied to the latent
                distribution; only consulted when ``mode`` is falsy.
            mode: When truthy, return ``latent_dist.mode()`` and ignore
                ``regulation``.

        Returns:
            ``latent_dist.mode()`` if ``mode`` is truthy, otherwise
            ``regulation(latent_dist)`` if ``regulation`` is not None,
            otherwise ``latent_dist.sample()``.
        """
        posterior = super().encode(x).latent_dist

        # ``mode`` takes precedence over ``regulation``.
        if mode:
            return posterior.mode()

        if regulation is None:
            return posterior.sample()

        return regulation(posterior)

    def decode(self, x):
        """Decode ``x`` with the parent and return its ``.sample`` attribute."""
        decoded = super().decode(x)
        return decoded.sample
|
||||||
|
|
||||||
|
|
||||||
def load_vae_from_state_dict(state_dict):
|
def load_vae_from_state_dict(state_dict):
|
||||||
config = guess_vae_config(state_dict)
|
config = guess_vae_config(state_dict)
|
||||||
|
|
||||||
with using_forge_operations():
|
with using_forge_operations():
|
||||||
model = AutoencoderKL(**config)
|
model = BaseVAE(**config)
|
||||||
|
|
||||||
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
|
vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
|
||||||
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
|
vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
|
||||||
|
|||||||
@@ -163,48 +163,20 @@ class CLIP:
|
|||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None, no_init=False):
|
def __init__(self, model=None, mapping=None, device=None, dtype=None, no_init=False):
|
||||||
|
if mapping is None:
|
||||||
|
mapping = {}
|
||||||
|
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
|
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
|
||||||
|
|
||||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
|
||||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = 4
|
||||||
|
|
||||||
if config is None:
|
self.first_stage_model = model.eval()
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
self.state_dict_mapping = mapping
|
||||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
|
||||||
decoder_config = encoder_config.copy()
|
|
||||||
decoder_config["video_kernel_size"] = [3, 1, 1]
|
|
||||||
decoder_config["alpha"] = 0.0
|
|
||||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
|
||||||
encoder_config={'target': "ldm_patched.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
|
||||||
decoder_config={'target': "ldm_patched.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
|
||||||
elif "taesd_decoder.1.weight" in sd:
|
|
||||||
self.first_stage_model = ldm_patched.taesd.taesd.TAESD()
|
|
||||||
else:
|
|
||||||
#default SD1.x/SD2.x VAE parameters
|
|
||||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
|
||||||
|
|
||||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
|
||||||
ddconfig['ch_mult'] = [1, 2, 4]
|
|
||||||
self.downscale_ratio = 4
|
|
||||||
|
|
||||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
|
||||||
else:
|
|
||||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
|
||||||
self.first_stage_model = self.first_stage_model.eval()
|
|
||||||
|
|
||||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0:
|
|
||||||
print("Missing VAE keys", m)
|
|
||||||
|
|
||||||
if len(u) > 0:
|
|
||||||
print("Leftover VAE keys", u)
|
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.vae_device()
|
device = model_management.vae_device()
|
||||||
@@ -216,7 +188,11 @@ class VAE:
|
|||||||
self.first_stage_model.to(self.vae_dtype)
|
self.first_stage_model.to(self.vae_dtype)
|
||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
|
||||||
|
self.first_stage_model,
|
||||||
|
load_device=self.device,
|
||||||
|
offload_device=offload_device
|
||||||
|
)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = VAE(no_init=True)
|
n = VAE(no_init=True)
|
||||||
@@ -226,6 +202,7 @@ class VAE:
|
|||||||
n.downscale_ratio = self.downscale_ratio
|
n.downscale_ratio = self.downscale_ratio
|
||||||
n.latent_channels = self.latent_channels
|
n.latent_channels = self.latent_channels
|
||||||
n.first_stage_model = self.first_stage_model
|
n.first_stage_model = self.first_stage_model
|
||||||
|
n.state_dict_mapping = self.state_dict_mapping
|
||||||
n.device = self.device
|
n.device = self.device
|
||||||
n.vae_dtype = self.vae_dtype
|
n.vae_dtype = self.vae_dtype
|
||||||
n.output_device = self.output_device
|
n.output_device = self.output_device
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from ldm.util import instantiate_from_config
|
|||||||
from modules_forge import forge_clip
|
from modules_forge import forge_clip
|
||||||
from modules_forge.unet_patcher import UnetPatcher
|
from modules_forge.unet_patcher import UnetPatcher
|
||||||
from ldm_patched.modules.model_base import model_sampling, ModelType
|
from ldm_patched.modules.model_base import model_sampling, ModelType
|
||||||
|
from backend.vae.loader import load_vae_from_state_dict
|
||||||
|
|
||||||
import open_clip
|
import open_clip
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
@@ -105,9 +106,8 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
vae, mapping = load_vae_from_state_dict(sd)
|
||||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
vae = VAE(model=vae, mapping=mapping)
|
||||||
vae = VAE(sd=vae_sd)
|
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
|
|||||||
Reference in New Issue
Block a user