Integrate Native AutoencoderKL

layerdiffusion
2024-07-31 21:08:22 -07:00
parent 2c3afff371
commit 0d079a846d
9 changed files with 473 additions and 93 deletions

backend/attention.py

@@ -54,14 +54,39 @@ def attention_pytorch(q, k, v, heads, mask=None):
     return out


+def attention_xformers_single_head(q, k, v):
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
+        (q, k, v),
+    )
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    out = out.transpose(1, 2).reshape(B, C, H, W)
+    return out
+
+
+def attention_pytorch_single_head(q, k, v):
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
+        (q, k, v),
+    )
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+    out = out.transpose(2, 3).reshape(B, C, H, W)
+    return out
+
+
 attention_function = attention_pytorch
+attention_function_single_head = attention_pytorch_single_head

 if args.xformers:
     print("Using xformers cross attention")
     attention_function = attention_xformers
+    attention_function_single_head = attention_xformers_single_head
 else:
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
+    attention_function_single_head = attention_pytorch_single_head


 class AttentionProcessorForge:
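Both single-head helpers flatten the H x W grid into a sequence of H*W tokens before calling the backend kernel. A minimal sketch (PyTorch only, no xformers needed, made-up tensor sizes) checking the scaled_dot_product_attention path against naive softmax attention:

import torch

q = k = v = torch.randn(2, 64, 16, 16)  # B, C, H, W
out = attention_pytorch_single_head(q, k, v)

# Naive reference: tokens are the H*W spatial positions, head dim is C.
B, C, H, W = q.shape
qf = q.view(B, C, -1).transpose(1, 2)  # (B, H*W, C)
attn = torch.softmax(qf @ qf.transpose(1, 2) / C ** 0.5, dim=-1)
ref = (attn @ qf).transpose(1, 2).reshape(B, C, H, W)
assert torch.allclose(out, ref, atol=1e-4)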

View File

@@ -1,11 +1,8 @@
 import os
 import importlib
 import diffusers
 import transformers
-from diffusers.loaders.single_file_utils import fetch_diffusers_config
-from diffusers import DiffusionPipeline
-from diffusers import AutoencoderKL
 from backend.vae import load_vae
@@ -13,13 +10,12 @@ dir_path = os.path.dirname(__file__)

 def load_component(component_name, lib_name, cls_name, repo_path, sd):
+    config_path = os.path.join(repo_path, component_name)
+
     if component_name in ['scheduler', 'tokenizer']:
         cls = getattr(importlib.import_module(lib_name), cls_name)
         return cls.from_pretrained(os.path.join(repo_path, component_name))

     if cls_name in ['AutoencoderKL']:
-        config = AutoencoderKL.load_config(os.path.join(repo_path, component_name))
-        return load_vae(sd, config)
+        return load_vae(sd, config_path)

     return None
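For reference, a hedged usage sketch of the two branches; repo_path, checkpoint_state_dict, and the scheduler class name are placeholders, not values from this commit:

# Scheduler/tokenizer components are instantiated through the library class.
scheduler = load_component('scheduler', 'diffusers', 'EulerDiscreteScheduler',
                           repo_path='/path/to/repo', sd=None)

# The VAE branch now hands the raw checkpoint state dict plus the folder
# holding the component's config.json to backend.vae.load_vae.
vae = load_component('vae', 'diffusers', 'AutoencoderKL',
                     repo_path='/path/to/repo', sd=checkpoint_state_dict)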

backend/nn/autoencoder_kl.py

@@ -0,0 +1,422 @@
import torch
import numpy as np

from backend.attention import attention_function_single_head
from diffusers.configuration_utils import ConfigMixin, register_to_config
from typing import Optional, Tuple
from torch import nn


def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class DiagonalGaussianDistribution:
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def mode(self):
        return self.mean
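The encoder emits 2 * latent_channels feature maps, which this class splits into a mean and a clamped log-variance; sample() then applies the reparameterization trick. A small sanity sketch with made-up tensor sizes, using DiagonalGaussianDistribution as defined above:

import torch

# Hypothetical encoder output for a 4-channel latent space: the first 4
# channels are the mean, the last 4 the log-variance.
params = torch.randn(1, 8, 64, 64)
dist = DiagonalGaussianDistribution(params)

z = dist.sample()    # stochastic: mean + std * N(0, 1)
z_map = dist.mode()  # deterministic: the mean itself

assert z.shape == (1, 4, 64, 64)
assert torch.equal(z_map, params[:, :4])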
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        try:
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        except Exception as e:
            # Fallback for inputs where the one-shot interpolate fails (e.g.
            # out-of-memory on large tensors): upsample in float32, a slice
            # of channels at a time, then cast back to the input dtype.
            b, c, h, w = x.shape
            out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
            split = 8
            l = out.shape[1] // split
            for i in range(0, out.shape[1], l):
                out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=2.0,
                                                                  mode="nearest").to(x.dtype)
            del x
            x = out

        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=2,
                                  padding=0)

    def forward(self, x):
        if self.with_conv:
            # Pad only right/bottom so the stride-2, padding-0 conv halves
            # the resolution exactly.
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
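The asymmetric (0, 1, 0, 1) pad is what makes the stride-2, padding-0 convolution halve the resolution exactly; a quick shape check with arbitrary sizes:

import torch

x = torch.randn(1, 4, 8, 8)
x = torch.nn.functional.pad(x, (0, 1, 0, 1))  # pad right/bottom: 8x8 -> 9x9
conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=0)
assert conv(x).shape == (1, 4, 4, 4)  # floor((9 - 3) / 2) + 1 = 4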
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels,
                                       out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels,
                                               out_channels,
                                               kernel_size=3,
                                               stride=1,
                                               padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels,
                                              out_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.swish(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.k = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.v = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.proj_out = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        h_ = attention_function_single_head(q, k, v)
        h_ = self.proj_out(h_)
        return x + h_
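AttnBlock is spatial self-attention with a residual connection: 1x1 convolutions produce q, k, v, the sequence length is H*W, and shapes are preserved end to end. A shape-only sketch (channel count arbitrary; assumes the backend.attention import above resolves):

import torch

blk = AttnBlock(in_channels=512)
x = torch.randn(1, 512, 32, 32)
with torch.no_grad():
    y = blk(x)  # x + proj_out(attention(q, k, v))
assert y.shape == x.shape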
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.conv_in = nn.Conv2d(in_channels,
                                 self.ch,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in,
                                  2 * z_channels if double_z else z_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        temb = None

        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
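With an SD-1.x-style config, ch=128 and ch_mult=(1, 2, 4, 4) give three Downsample stages, so spatial resolution drops by 8x, and double_z makes conv_out emit 2 * z_channels maps. A hedged shape walkthrough; the config values are common SD 1.x defaults, not taken from this commit:

import torch

enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], in_channels=3, resolution=512,
              z_channels=4, double_z=True)
with torch.no_grad():
    moments = enc(torch.randn(1, 3, 512, 512))
assert moments.shape == (1, 8, 64, 64)  # 2 * z_channels, 512 / 2**3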
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 conv_out_op=nn.Conv2d,
                 resnet_op=ResnetBlock,
                 **kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        self.conv_in = nn.Conv2d(z_channels,
                                 block_in,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(resnet_op(in_channels=block_in,
                                       out_channels=block_out,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_out_op(block_in,
                                    out_ch,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, z, **kwargs):
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
    config_name = 'config.json'

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: float = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
    ):
        super().__init__()

        # Translate the diffusers-style config into ldm-style arguments:
        # block_out_channels=(128, 256, 512, 512) becomes ch=128 with
        # ch_mult=[1, 2, 4, 4].
        ch = block_out_channels[0]
        ch_mult = [x // ch for x in block_out_channels]

        self.encoder = Encoder(double_z=True, z_channels=latent_channels, resolution=256,
                               in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256,
                               in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.embed_dim = latent_channels

    def encode(self, x, regulation=None):
        z = self.encoder(x)
        z = self.quant_conv(z)
        posterior = DiagonalGaussianDistribution(z)
        if regulation is not None:
            return regulation(posterior)
        else:
            return posterior.sample()

    def decode(self, z):
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        return x
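End to end, encode() samples a posterior over latents and decode() maps latents back to pixels. A round-trip sketch assuming common SD 1.x config values (block_out_channels and friends are assumptions, not read from a real config.json):

import torch

vae = IntegratedAutoencoderKL(block_out_channels=(128, 256, 512, 512),
                              layers_per_block=2, latent_channels=4)
img = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    z = vae.encode(img)    # sampled posterior, (1, 4, 32, 32)
    rec = vae.decode(z)    # (1, 3, 256, 256)
assert z.shape == (1, 4, 32, 32) and rec.shape == img.shape

Because the module keeps the original ldm parameter names (encoder.down.0.block.0 and so on), a first_stage_model state dict can load with strict=True and no key conversion, which is what the simplified load_vae below relies on.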

backend/nn/dummy.py (new file, +12 lines)

@@ -0,0 +1,12 @@
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from torch import nn


class Dummy(nn.Module, ConfigMixin):
    config_name = 'config.json'

    @register_to_config
    def __init__(self):
        super().__init__()

backend/state_dict.py

@@ -1,56 +1,12 @@
 import torch


-class StateDictItem:
-    def __init__(self, key, value, advanced_indexing=None):
-        self.key = key
-        self.value = value
-        self.shape = value.shape
-        self.advanced_indexing = advanced_indexing
-
-    def __getitem__(self, advanced_indexing):
-        t = self.value[advanced_indexing]
-        return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
-
-
-def split_state_dict_with_prefix(sd, prefix):
-    vae_sd = {}
+def filter_state_dict_with_prefix(sd, prefix):
+    new_sd = {}

     for k, v in list(sd.items()):
         if k.startswith(prefix):
-            vae_sd[k] = StateDictItem(k[len(prefix):], v)
+            new_sd[k[len(prefix):]] = v
             del sd[k]

-    return vae_sd
+    return new_sd


-def compile_state_dict(state_dict):
-    sd = {}
-    mapping = {}
-    for k, v in state_dict.items():
-        sd[k] = v.value
-        mapping[v.key] = (k, v.advanced_indexing)
-    return sd, mapping
-
-
-def map_state_dict(sd, mapping):
-    new_sd = {}
-    for k, v in sd.items():
-        k, indexing = mapping.get(k, (k, None))
-        if indexing is not None:
-            v = v[indexing]
-        new_sd[k] = v
-    return new_sd
-
-
-def map_state_dict_heuristic(sd, mapping):
-    new_mapping = {}
-    for k, (v, _) in mapping:
-        new_mapping[k.rpartition('.')[0]] = v.rpartition('.')[0]
-    new_sd = {}
-    for k, v in sd.items():
-        l, m, r = k.rpartition('.')
-        l = new_mapping.get(l, l)
-        new_sd[l + m + r] = v
-    return new_sd
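filter_state_dict_with_prefix both strips the prefix and consumes the matching keys from the source dict in place; a small sketch:

import torch

sd = {
    'first_stage_model.encoder.conv_in.weight': torch.zeros(128, 3, 3, 3),
    'model.diffusion_model.out.0.weight': torch.zeros(4),
}
vae_sd = filter_state_dict_with_prefix(sd, 'first_stage_model.')

assert list(vae_sd) == ['encoder.conv_in.weight']          # prefix stripped
assert list(sd) == ['model.diffusion_model.out.0.weight']  # removed from sd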

backend/vae.py

@@ -1,38 +1,14 @@
-from diffusers import AutoencoderKL
-from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
+from backend.state_dict import filter_state_dict_with_prefix
 from backend.operations import using_forge_operations
-from backend.attention import AttentionProcessorForge
-from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
+from backend.nn.autoencoder_kl import IntegratedAutoencoderKL


-class BaseAutoencoderKL(AutoencoderKL):
-    def __init__(self, *args, **kwargs):
+def load_vae(state_dict, config_path):
+    config = IntegratedAutoencoderKL.load_config(config_path)
-        super().__init__(*args, **kwargs)
-        self.state_dict_mapping = {}
-
-    def encode(self, x, regulation=None, mode=False):
-        latent_dist = super().encode(x).latent_dist
-        if mode:
-            return latent_dist.mode()
-        elif regulation is not None:
-            return regulation(latent_dist)
-        else:
-            return latent_dist.sample()
-
-    def decode(self, x):
-        return super().decode(x).sample
-
-
-def load_vae(state_dict, config):
     with using_forge_operations():
-        model = BaseAutoencoderKL(**config)
+        model = IntegratedAutoencoderKL.from_config(config)

-    vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
-    vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
-    vae_state_dict, mapping = compile_state_dict(vae_state_dict)
+    vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.")
     model.load_state_dict(vae_state_dict, strict=True)
-    model.set_attn_processor(AttentionProcessorForge())
-    model.state_dict_mapping = mapping
     return model
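Putting it together, a hedged usage sketch; the checkpoint filename and config folder are hypothetical, and the checkpoint is assumed to carry ldm-style first_stage_model.* keys:

import safetensors.torch

from backend.vae import load_vae

state_dict = safetensors.torch.load_file('/path/to/v1-5-checkpoint.safetensors')
vae = load_vae(state_dict, config_path='/path/to/vae')  # folder with config.json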

View File

@@ -163,10 +163,7 @@ class CLIP:
         return self.patcher.get_key_patches()


 class VAE:
-    def __init__(self, model=None, mapping=None, device=None, dtype=None, no_init=False):
-        if mapping is None:
-            mapping = {}
+    def __init__(self, model=None, device=None, dtype=None, no_init=False):
         if no_init:
             return
@@ -176,7 +173,6 @@ class VAE:
         self.latent_channels = 4
         self.first_stage_model = model.eval()
-        self.state_dict_mapping = mapping

         if device is None:
             device = model_management.vae_device()
@@ -202,7 +198,6 @@
         n.downscale_ratio = self.downscale_ratio
         n.latent_channels = self.latent_channels
         n.first_stage_model = self.first_stage_model
-        n.state_dict_mapping = self.state_dict_mapping
         n.device = self.device
         n.vae_dtype = self.vae_dtype
         n.output_device = self.output_device

modules/sd_vae.py

@@ -3,7 +3,6 @@ import collections
 from dataclasses import dataclass

 from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
-from backend.state_dict import map_state_dict

 import glob
 from copy import deepcopy
@@ -237,8 +236,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):

 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
-    sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping)
-    model.first_stage_model.load_state_dict(sd_mapped)
+    model.first_stage_model.load_state_dict(vae_dict_1)


 def clear_loaded_vae():

View File

@@ -109,7 +109,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
     if output_vae:
         vae = huggingface_components['vae']
-        vae = VAE(model=vae, mapping=vae.state_dict_mapping)
+        vae = VAE(model=vae)

     if output_clip:
         w = WeightsLoader()