stable-diffusion-webui-forge (mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git)

Commit: Integrate Native AutoencoderKL
backend/attention.py
@@ -54,14 +54,39 @@ def attention_pytorch(q, k, v, heads, mask=None):
     return out
 
 
+def attention_xformers_single_head(q, k, v):
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
+        (q, k, v),
+    )
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    out = out.transpose(1, 2).reshape(B, C, H, W)
+    return out
+
+
+def attention_pytorch_single_head(q, k, v):
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
+        (q, k, v),
+    )
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+    out = out.transpose(2, 3).reshape(B, C, H, W)
+    return out
+
+
 attention_function = attention_pytorch
+attention_function_single_head = attention_pytorch_single_head
 
 if args.xformers:
     print("Using xformers cross attention")
     attention_function = attention_xformers
+    attention_function_single_head = attention_xformers_single_head
 else:
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
+    attention_function_single_head = attention_pytorch_single_head
 
 
 class AttentionProcessorForge:
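Note: both single-head helpers flatten a (B, C, H, W) feature map into H*W tokens of dimension C, attend over all spatial positions, and restore the original layout. A minimal shape-check sketch (not part of the commit; the import path is assumed from this diff):

    import torch
    from backend.attention import attention_pytorch_single_head

    q = k = v = torch.randn(1, 512, 32, 32)       # a VAE mid-block sized feature map
    out = attention_pytorch_single_head(q, k, v)
    assert out.shape == (1, 512, 32, 32)          # spatial attention preserves shape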
@@ -1,11 +1,8 @@
 import os
 import importlib
 import diffusers
 import transformers
 
 from diffusers.loaders.single_file_utils import fetch_diffusers_config
 from diffusers import DiffusionPipeline
-from diffusers import AutoencoderKL
 from backend.vae import load_vae
@@ -13,13 +10,12 @@ dir_path = os.path.dirname(__file__)
 
 
 def load_component(component_name, lib_name, cls_name, repo_path, sd):
     config_path = os.path.join(repo_path, component_name)
     if component_name in ['scheduler', 'tokenizer']:
         cls = getattr(importlib.import_module(lib_name), cls_name)
         return cls.from_pretrained(os.path.join(repo_path, component_name))
     if cls_name in ['AutoencoderKL']:
-        config = AutoencoderKL.load_config(os.path.join(repo_path, component_name))
-        return load_vae(sd, config)
+        return load_vae(sd, config_path)
     return None
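Note: with this change the VAE branch hands the raw checkpoint state dict and the component's config directory straight to backend.vae.load_vae. A hypothetical call, with paths invented for illustration:

    vae = load_component('vae', 'diffusers', 'AutoencoderKL', '/path/to/repo', sd)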
backend/nn/autoencoder_kl.py (new file, 422 lines)
@@ -0,0 +1,422 @@
import torch
import numpy as np

from backend.attention import attention_function_single_head
from diffusers.configuration_utils import ConfigMixin, register_to_config
from typing import Optional, Tuple
from torch import nn


def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
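Note: nonlinearity is the swish/SiLU activation, x * sigmoid(x). A one-line sanity sketch against PyTorch's built-in (elementwise equivalent):

    import torch
    x = torch.randn(8)
    assert torch.allclose(x * torch.sigmoid(x), torch.nn.functional.silu(x))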
class DiagonalGaussianDistribution:
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def mode(self):
        return self.mean
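Note: the encoder's output packs mean and log-variance along the channel axis; sample() then applies the reparameterization trick z = mean + std * eps. A minimal sketch with assumed SD-style shapes (8 moment channels split into 4 + 4):

    import torch
    moments = torch.randn(1, 8, 32, 32)           # encoder + quant_conv output
    dist = DiagonalGaussianDistribution(moments)
    z = dist.sample()                             # stochastic latent
    assert z.shape == (1, 4, 32, 32)
    assert torch.equal(dist.mode(), dist.mean)    # mode() is deterministic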
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        try:
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        except Exception as e:
            # Fallback: interpolate in float32, a slice of channels at a time,
            # to work around dtype/memory failures on some backends.
            b, c, h, w = x.shape
            out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
            split = 8
            l = out.shape[1] // split
            for i in range(0, out.shape[1], l):
                out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=2.0,
                                                                  mode="nearest").to(x.dtype)
            del x
            x = out

        if self.with_conv:
            x = self.conv(x)
        return x
class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=2,
                                  padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)  # pad right/bottom only, as in the original LDM VAE
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
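Note: the asymmetric (0, 1, 0, 1) padding pads only the right and bottom edges, so the stride-2 3x3 convolution halves an even-sized map exactly. A quick sketch:

    import torch
    down = Downsample(16, with_conv=True)
    assert down(torch.randn(1, 16, 64, 64)).shape == (1, 16, 32, 32)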
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels,
                                       out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels,
                                               out_channels,
                                               kernel_size=3,
                                               stride=1,
                                               padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels,
                                              out_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.swish(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
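Note: when in_channels != out_channels, the skip path is projected (by default through the 1x1 "network-in-network" shortcut) so the residual sum is well-defined. A shape sketch; temb_channels=0 disables the timestep projection, which is how this VAE uses the block:

    import torch
    block = ResnetBlock(in_channels=128, out_channels=256, dropout=0.0, temb_channels=0)
    y = block(torch.randn(1, 128, 64, 64), temb=None)
    assert y.shape == (1, 256, 64, 64)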
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.k = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.v = nn.Conv2d(in_channels,
                           in_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)
        self.proj_out = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        h_ = attention_function_single_head(q, k, v)
        h_ = self.proj_out(h_)
        return x + h_
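Note: q, k and v are 1x1 convolutions, so AttnBlock mixes information across spatial positions only and is shape-preserving:

    import torch
    attn = AttnBlock(512)
    assert attn(torch.randn(1, 512, 32, 32)).shape == (1, 512, 32, 32)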
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.conv_in = nn.Conv2d(in_channels,
                                 self.ch,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in,
                                  2 * z_channels if double_z else z_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        temb = None
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
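Note: with double_z=True the encoder emits 2 * z_channels "moment" channels (mean and log-variance) at 1 / 2**(len(ch_mult) - 1) of the input resolution. A sketch with assumed SD1-style arguments:

    import torch
    enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                  attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
    moments = enc(torch.randn(1, 3, 256, 256))
    assert moments.shape == (1, 8, 32, 32)       # 8x spatial downsample, 2*4 channels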
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 conv_out_op=nn.Conv2d,
                 resnet_op=ResnetBlock,
                 **kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        self.conv_in = nn.Conv2d(z_channels,
                                 block_in,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(resnet_op(in_channels=block_in,
                                       out_channels=block_out,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_out_op(block_in,
                                    out_ch,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, z, **kwargs):
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
    config_name = 'config.json'

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: float = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
    ):
        super().__init__()
        ch = block_out_channels[0]
        ch_mult = [x // ch for x in block_out_channels]
        self.encoder = Encoder(double_z=True, z_channels=latent_channels, resolution=256,
                               in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256,
                               in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
                               num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.embed_dim = latent_channels

    def encode(self, x, regulation=None):
        z = self.encoder(x)
        z = self.quant_conv(z)
        posterior = DiagonalGaussianDistribution(z)
        if regulation is not None:
            return regulation(posterior)
        else:
            return posterior.sample()

    def decode(self, z):
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        return x
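Note: the diffusers-style block_out_channels config is converted back to the ch/ch_mult convention of the original LDM code (ch is the first entry, ch_mult is each entry divided by it). A minimal round-trip sketch, with config values assumed to mirror the SD1-style defaults:

    import torch
    vae = IntegratedAutoencoderKL(block_out_channels=(128, 256, 512, 512),
                                  layers_per_block=2, latent_channels=4)
    z = vae.encode(torch.randn(1, 3, 256, 256))   # (1, 4, 32, 32) sampled latent
    rec = vae.decode(z)                           # (1, 3, 256, 256) reconstruction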
backend/nn/dummy.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from torch import nn


class Dummy(nn.Module, ConfigMixin):
    config_name = 'config.json'

    @register_to_config
    def __init__(self):
        super().__init__()
backend/state_dict.py
@@ -1,56 +1,12 @@
 import torch
 
 
-class StateDictItem:
-    def __init__(self, key, value, advanced_indexing=None):
-        self.key = key
-        self.value = value
-        self.shape = value.shape
-        self.advanced_indexing = advanced_indexing
-
-    def __getitem__(self, advanced_indexing):
-        t = self.value[advanced_indexing]
-        return StateDictItem(self.key, t, advanced_indexing=advanced_indexing)
-
-
-def split_state_dict_with_prefix(sd, prefix):
-    vae_sd = {}
+def filter_state_dict_with_prefix(sd, prefix):
+    new_sd = {}
 
     for k, v in list(sd.items()):
         if k.startswith(prefix):
-            vae_sd[k] = StateDictItem(k[len(prefix):], v)
+            new_sd[k[len(prefix):]] = v
             del sd[k]
 
-    return vae_sd
-
-
-def compile_state_dict(state_dict):
-    sd = {}
-    mapping = {}
-    for k, v in state_dict.items():
-        sd[k] = v.value
-        mapping[v.key] = (k, v.advanced_indexing)
-    return sd, mapping
-
-
-def map_state_dict(sd, mapping):
-    new_sd = {}
-    for k, v in sd.items():
-        k, indexing = mapping.get(k, (k, None))
-        if indexing is not None:
-            v = v[indexing]
-        new_sd[k] = v
-    return new_sd
-
-
-def map_state_dict_heuristic(sd, mapping):
-    new_mapping = {}
-    for k, (v, _) in mapping:
-        new_mapping[k.rpartition('.')[0]] = v.rpartition('.')[0]
-
-    new_sd = {}
-    for k, v in sd.items():
-        l, m, r = k.rpartition('.')
-        l = new_mapping.get(l, l)
-        new_sd[l + m + r] = v
-    return new_sd
+    return new_sd
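Note: the surviving helper both filters and mutates: matching entries are deleted from the input dict and re-keyed without the prefix. A small sketch of the assumed behavior:

    sd = {'first_stage_model.encoder.conv_in.weight': 0, 'model.diffusion_model.out.weight': 1}
    vae_sd = filter_state_dict_with_prefix(sd, 'first_stage_model.')
    assert vae_sd == {'encoder.conv_in.weight': 0}
    assert sd == {'model.diffusion_model.out.weight': 1}     # prefixed entries removed in place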
backend/vae.py
@@ -1,38 +1,14 @@
-from diffusers import AutoencoderKL
-from backend.state_dict import split_state_dict_with_prefix, compile_state_dict
+from backend.state_dict import filter_state_dict_with_prefix
 from backend.operations import using_forge_operations
-from backend.attention import AttentionProcessorForge
-from diffusers.loaders.single_file_model import convert_ldm_vae_checkpoint
+from backend.nn.autoencoder_kl import IntegratedAutoencoderKL
 
 
-class BaseAutoencoderKL(AutoencoderKL):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.state_dict_mapping = {}
-
-    def encode(self, x, regulation=None, mode=False):
-        latent_dist = super().encode(x).latent_dist
-        if mode:
-            return latent_dist.mode()
-        elif regulation is not None:
-            return regulation(latent_dist)
-        else:
-            return latent_dist.sample()
-
-    def decode(self, x):
-        return super().decode(x).sample
-
-
-def load_vae(state_dict, config):
+def load_vae(state_dict, config_path):
+    config = IntegratedAutoencoderKL.load_config(config_path)
+
     with using_forge_operations():
-        model = BaseAutoencoderKL(**config)
+        model = IntegratedAutoencoderKL.from_config(config)
 
-    vae_state_dict = split_state_dict_with_prefix(state_dict, "first_stage_model.")
-    vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict, config)
-    vae_state_dict, mapping = compile_state_dict(vae_state_dict)
+    vae_state_dict = filter_state_dict_with_prefix(state_dict, "first_stage_model.")
     model.load_state_dict(vae_state_dict, strict=True)
-    model.set_attn_processor(AttentionProcessorForge())
-    model.state_dict_mapping = mapping
-
     return model
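Note: load_vae now consumes the checkpoint's native first_stage_model keys directly; no convert_ldm_vae_checkpoint remapping is needed because IntegratedAutoencoderKL keeps the original LDM layout. Hypothetical usage, with the loader helper and paths assumed rather than taken from this diff:

    sd = load_torch_file('/path/to/checkpoint.safetensors')   # any state-dict loader
    vae_model = load_vae(sd, '/path/to/huggingface_repo/vae')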
@@ -163,10 +163,7 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 
 class VAE:
-    def __init__(self, model=None, mapping=None, device=None, dtype=None, no_init=False):
-        if mapping is None:
-            mapping = {}
-
+    def __init__(self, model=None, device=None, dtype=None, no_init=False):
         if no_init:
             return
@@ -176,7 +173,6 @@ class VAE:
         self.latent_channels = 4
 
         self.first_stage_model = model.eval()
-        self.state_dict_mapping = mapping
 
         if device is None:
             device = model_management.vae_device()
@@ -202,7 +198,6 @@ class VAE:
         n.downscale_ratio = self.downscale_ratio
         n.latent_channels = self.latent_channels
         n.first_stage_model = self.first_stage_model
-        n.state_dict_mapping = self.state_dict_mapping
         n.device = self.device
         n.vae_dtype = self.vae_dtype
         n.output_device = self.output_device
modules/sd_vae.py
@@ -3,7 +3,6 @@ import collections
 from dataclasses import dataclass
 
 from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
-from backend.state_dict import map_state_dict
 
 import glob
 from copy import deepcopy
@@ -237,8 +236,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
 
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
-    sd_mapped = map_state_dict(vae_dict_1, model.first_stage_model.state_dict_mapping)
-    model.first_stage_model.load_state_dict(sd_mapped)
+    model.first_stage_model.load_state_dict(vae_dict_1)
 
 
 def clear_loaded_vae():
@@ -109,7 +109,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
 
     if output_vae:
         vae = huggingface_components['vae']
-        vae = VAE(model=vae, mapping=vae.state_dict_mapping)
+        vae = VAE(model=vae)
 
     if output_clip:
         w = WeightsLoader()