diff --git a/backend/attention.py b/backend/attention.py index f6e965be..5e891104 100644 --- a/backend/attention.py +++ b/backend/attention.py @@ -54,14 +54,39 @@ def attention_pytorch(q, k, v, heads, mask=None): return out +def attention_xformers_single_head(q, k, v): + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + out = out.transpose(1, 2).reshape(B, C, H, W) + return out + + +def attention_pytorch_single_head(q, k, v): + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), + (q, k, v), + ) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + return out + + attention_function = attention_pytorch +attention_function_single_head = attention_pytorch_single_head if args.xformers: print("Using xformers cross attention") attention_function = attention_xformers + attention_function_single_head = attention_xformers_single_head else: print("Using pytorch cross attention") attention_function = attention_pytorch + attention_function_single_head = attention_pytorch_single_head class AttentionProcessorForge: diff --git a/backend/loader.py b/backend/loader.py index 810e355f..f513f391 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -1,11 +1,8 @@ import os import importlib -import diffusers -import transformers from diffusers.loaders.single_file_utils import fetch_diffusers_config from diffusers import DiffusionPipeline -from diffusers import AutoencoderKL from backend.vae import load_vae @@ -13,13 +10,12 @@ dir_path = os.path.dirname(__file__) def load_component(component_name, lib_name, cls_name, repo_path, sd): + config_path = os.path.join(repo_path, component_name) if component_name in ['scheduler', 'tokenizer']: cls = getattr(importlib.import_module(lib_name), cls_name) return cls.from_pretrained(os.path.join(repo_path, component_name)) if cls_name in ['AutoencoderKL']: - config = AutoencoderKL.load_config(os.path.join(repo_path, component_name)) - return load_vae(sd, config) - + return load_vae(sd, config_path) return None diff --git a/backend/nn/autoencoder_kl.py b/backend/nn/autoencoder_kl.py new file mode 100644 index 00000000..8019c065 --- /dev/null +++ b/backend/nn/autoencoder_kl.py @@ -0,0 +1,422 @@ +import torch +import numpy as np + +from backend.attention import attention_function_single_head +from diffusers.configuration_utils import ConfigMixin, register_to_config +from typing import Optional, Tuple +from torch import nn + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class DiagonalGaussianDistribution: + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def mode(self): + return self.mean + + +class Upsample(nn.Module): + def 
__init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + try: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + except Exception as e: + b, c, h, w = x.shape + out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device) + split = 8 + l = out.shape[1] // split + for i in range(0, out.shape[1], l): + out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=2.0, + mode="nearest").to(x.dtype) + del x + x = out + + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.swish = torch.nn.SiLU(inplace=True) + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout, inplace=True) + self.conv2 = nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.swish(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.swish(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.swish(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + h_ = attention_function_single_head(q, k, v) + h_ = self.proj_out(h_) + return x + h_ + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), 
num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.conv_in = nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2d(block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + temb = None + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + conv_out_op=nn.Conv2d, + resnet_op=ResnetBlock, + **kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + self.conv_in = nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + self.mid = nn.Module() + self.mid.block_1 = resnet_op(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = resnet_op(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + self.up = 
nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(resnet_op(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = conv_out_op(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z, **kwargs): + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class IntegratedAutoencoderKL(nn.Module, ConfigMixin): + config_name = 'config.json' + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + shift_factor: Optional[float] = None, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, + force_upcast: float = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + ch = block_out_channels[0] + ch_mult = [x // ch for x in block_out_channels] + self.encoder = Encoder(double_z=True, z_channels=latent_channels, resolution=256, + in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult, + num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0) + self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256, + in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult, + num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + self.embed_dim = latent_channels + + def encode(self, x, regulation=None): + z = self.encoder(x) + z = self.quant_conv(z) + posterior = DiagonalGaussianDistribution(z) + if regulation is not None: + return regulation(posterior) + else: + return posterior.sample() + + def decode(self, z): + z = self.post_quant_conv(z) + x = self.decoder(z) + return x diff --git a/backend/nn/dummy.py b/backend/nn/dummy.py new file mode 100644 index 00000000..6243f60b --- /dev/null +++ b/backend/nn/dummy.py @@ -0,0 +1,12 @@ +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from torch import nn + + +class 
Dummy(nn.Module, ConfigMixin): + config_name = 'config.json' + + @register_to_config + def __init__(self): + super().__init__() diff --git a/backend/state_dict.py b/backend/state_dict.py index a3dd9102..5e86a854 100644 --- a/backend/state_dict.py +++ b/backend/state_dict.py @@ -1,56 +1,12 @@ import torch -class StateDictItem: - def __init__(self, key, value, advanced_indexing=None): - self.key = key - self.value = value - self.shape = value.shape - self.advanced_indexing = advanced_indexing - - def __getitem__(self, advanced_indexing): - t = self.value[advanced_indexing] - return StateDictItem(self.key, t, advanced_indexing=advanced_indexing) - - -def split_state_dict_with_prefix(sd, prefix): - vae_sd = {} +def filter_state_dict_with_prefix(sd, prefix): + new_sd = {} for k, v in list(sd.items()): if k.startswith(prefix): - vae_sd[k] = StateDictItem(k[len(prefix):], v) + new_sd[k[len(prefix):]] = v del sd[k] - return vae_sd - - -def compile_state_dict(state_dict): - sd = {} - mapping = {} - for k, v in state_dict.items(): - sd[k] = v.value - mapping[v.key] = (k, v.advanced_indexing) - return sd, mapping - - -def map_state_dict(sd, mapping): - new_sd = {} - for k, v in sd.items(): - k, indexing = mapping.get(k, (k, None)) - if indexing is not None: - v = v[indexing] - new_sd[k] = v - return new_sd - - -def map_state_dict_heuristic(sd, mapping): - new_mapping = {} - for k, (v, _) in mapping: - new_mapping[k.rpartition('.')[0]] = v.rpartition('.')[0] - - new_sd = {} - for k, v in sd.items(): - l, m, r = k.rpartition('.') - l = new_mapping.get(l, l) - new_sd[l + m + r] = v return new_sd diff --git a/backend/vae.py b/backend/vae.py index 6fcab0d8..e6649db0 100644 --- a/backend/vae.py +++ b/backend/vae.py @@ -1,38 +1,14 @@ -from diffusers import AutoencoderKL -from backend.state_dict import split_state_dict_with_prefix, compile_state_dict +from backend.state_dict import filter_state_dict_with_prefix from backend.operations import using_forge_operations -from backend.attention import AttentionProcessorForge -from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint +from backend.nn.autoencoder_kl import IntegratedAutoencoderKL -class BaseAutoencoderKL(AutoencoderKL): - def __init__(self, *args, **kwargs): +def load_vae(state_dict, config_path): + config = IntegratedAutoencoderKL.load_config(config_path) - super().__init__(*args, **kwargs) - self.state_dict_mapping = {} - - def encode(self, x, regulation=None, mode=False): - latent_dist = super().encode(x).latent_dist - if mode: - return latent_dist.mode() - elif regulation is not None: - return regulation(latent_dist) - else: - return latent_dist.sample() - - def decode(self, x): - return super().decode(x).sample - - -def load_vae(state_dict, config): with using_forge_operations(): - model = BaseAutoencoderKL(**config) + model = IntegratedAutoencoderKL.from_config(config) - vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.") - vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config) - vae_state_dict, mapping = compile_state_dict(vae_state_dict) + vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.") model.load_state_dict(vae_state_dict, strict=True) - model.set_attn_processor(AttentionProcessorForge()) - model.state_dict_mapping = mapping - return model diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index 19550a07..fdf6f777 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -163,10 +163,7 @@ class CLIP: return 
self.patcher.get_key_patches() class VAE: - def __init__(self, model=None, mapping=None, device=None, dtype=None, no_init=False): - if mapping is None: - mapping = {} - + def __init__(self, model=None, device=None, dtype=None, no_init=False): if no_init: return @@ -176,7 +173,6 @@ class VAE: self.latent_channels = 4 self.first_stage_model = model.eval() - self.state_dict_mapping = mapping if device is None: device = model_management.vae_device() @@ -202,7 +198,6 @@ class VAE: n.downscale_ratio = self.downscale_ratio n.latent_channels = self.latent_channels n.first_stage_model = self.first_stage_model - n.state_dict_mapping = self.state_dict_mapping n.device = self.device n.vae_dtype = self.vae_dtype n.output_device = self.output_device diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e7ef5a77..62fd6524 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,7 +3,6 @@ import collections from dataclasses import dataclass from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes -from backend.state_dict import map_state_dict import glob from copy import deepcopy @@ -237,8 +236,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): # don't call this from outside def _load_vae_dict(model, vae_dict_1): - sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping) - model.first_stage_model.load_state_dict(sd_mapped) + model.first_stage_model.load_state_dict(vae_dict_1) def clear_loaded_vae(): diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index b5cd9e79..3c019b94 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -109,7 +109,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c if output_vae: vae = huggingface_components['vae'] - vae = VAE(model=vae, mapping=vae.state_dict_mapping) + vae = VAE(model=vae) if output_clip: w = WeightsLoader()
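For reference, the reshape contract of the new single-head attention helpers can be checked in isolation. A minimal sketch, assuming PyTorch 2.x for scaled_dot_product_attention (the xformers variant uses the same (B, C, H, W) -> tokens layout):

import torch

def attention_pytorch_single_head(q, k, v):
    # Treat the (B, C, H, W) feature map as a single attention head over H*W tokens of width C.
    B, C, H, W = q.shape
    q, k, v = map(lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    return out.transpose(2, 3).reshape(B, C, H, W)

q = k = v = torch.randn(1, 512, 32, 32)   # hypothetical mid-block activations
out = attention_pytorch_single_head(q, k, v)
assert out.shape == (1, 512, 32, 32)      # spatial layout preserved, as AttnBlock expects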
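The simplified state-dict handling can likewise be exercised on its own: keys are filtered by prefix and renamed destructively, so the VAE loads with strict=True and no mapping step. A small sketch with hypothetical stand-in tensors:

import torch

def filter_state_dict_with_prefix(sd, prefix):
    # Move matching keys out of sd, stripping the prefix from each key.
    new_sd = {}
    for k, v in list(sd.items()):
        if k.startswith(prefix):
            new_sd[k[len(prefix):]] = v
            del sd[k]
    return new_sd

sd = {
    "first_stage_model.encoder.conv_in.weight": torch.zeros(128, 3, 3, 3),
    "model.diffusion_model.time_embed.0.weight": torch.zeros(1280, 320),
}
vae_sd = filter_state_dict_with_prefix(sd, "first_stage_model.")
assert list(vae_sd) == ["encoder.conv_in.weight"]              # prefix stripped
assert "first_stage_model.encoder.conv_in.weight" not in sd    # consumed from the source dict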
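Because IntegratedAutoencoderKL keeps the ldm-style module names, its encode/decode round trip can be sketched directly. This assumes the Forge backend package and diffusers are importable, and uses the toy default config rather than a real SD VAE config (which would pass block_out_channels=(128, 256, 512, 512), latent_channels=4, and so on):

import torch
from backend.nn.autoencoder_kl import IntegratedAutoencoderKL

vae = IntegratedAutoencoderKL().eval()     # defaults: ch=64, ch_mult=[1], 4 latent channels

x = torch.randn(1, 3, 64, 64)              # hypothetical RGB batch
with torch.no_grad():
    z = vae.encode(x)                      # sampled latent, (1, 4, 64, 64) with this toy config
    y = vae.decode(z)                      # reconstruction, (1, 3, 64, 64)
assert y.shape == x.shape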