mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-31 05:19:45 +00:00
Before this commit, credits were already in entry and licenses were already in root. This commit makes the info clearer.
658 lines
24 KiB
Python
658 lines
24 KiB
Python
# 1st edit by https://github.com/CompVis/latent-diffusion
|
|
# 2nd edit by https://github.com/Stability-AI/stablediffusion
|
|
# 3rd edit by https://github.com/Stability-AI/generative-models
|
|
# 4th edit by https://github.com/comfyanonymous/ComfyUI
|
|
# 5th edit by Forge
|
|
|
|
|
|
# pytorch_diffusion + derived encoder decoder
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from einops import rearrange
|
|
from typing import Optional, Any
|
|
|
|
from ldm_patched.modules import model_management
|
|
import ldm_patched.modules.ops
|
|
ops = ldm_patched.modules.ops.disable_weight_init
|
|
|
|
if model_management.xformers_enabled_vae():
|
|
import xformers
|
|
import xformers.ops
|
|
|
|
def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models (from Fairseq). It matches the implementation in tensor2tensor,
    but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps (one entry per sample).
        embedding_dim: width of the embedding to produce.

    Returns:
        Float32 tensor of shape ``(len(timesteps), embedding_dim)`` on
        ``timesteps.device``: sines in the first half, cosines in the second,
        plus one trailing zero column when ``embedding_dim`` is odd.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # For half_dim == 1 the original divisor (half_dim - 1) is zero and the
    # division raised. Only the frequency exp(0) == 1 exists in that case, so
    # clamping the divisor gives the natural result without changing any
    # other width (for half_dim == 0 the frequency vector is empty anyway).
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
|
|
def nonlinearity(x):
    """Apply the swish activation: ``x`` multiplied by ``sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
|
|
def Normalize(in_channels, num_groups=32):
    """Return an affine ``ops.GroupNorm`` over ``in_channels`` channels.

    The layer uses ``num_groups`` groups (32 by default) and ``eps=1e-6``.
    """
    norm_layer = ops.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return norm_layer
|
|
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation.

    When ``with_conv`` is true, a 3x3 convolution (stride 1, padding 1) with
    ``in_channels`` input and output channels is applied after upsampling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = ops.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    @staticmethod
    def _interpolate_in_chunks(x, split=8):
        """Upsample ``x`` 2x by processing channel chunks in float32.

        Used when direct interpolation fails for the input dtype. Each chunk
        is converted to float32, interpolated, and cast back to ``x.dtype``.
        """
        b, c, h, w = x.shape
        out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
        # Clamp to 1: with fewer than `split` channels the integer division
        # is 0, and a zero step would make range() raise ValueError.
        chunk = max(1, c // split)
        for start in range(0, c, chunk):
            stop = start + chunk
            out[:, start:stop] = torch.nn.functional.interpolate(
                x[:, start:stop].to(torch.float32), scale_factor=2.0, mode="nearest"
            ).to(x.dtype)
        return out

    def forward(self, x):
        """Return ``x`` upsampled 2x, convolved when ``with_conv`` is set."""
        try:
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        except Exception:  # operation not implemented for bf16
            # Deliberately broad so every backend failure still falls back,
            # but no longer a bare except that would also swallow
            # KeyboardInterrupt / SystemExit.
            x = self._interpolate_in_chunks(x)

        if self.with_conv:
            x = self.conv(x)
        return x
|
|
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial size with a strided convolution or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # no asymmetric padding in torch conv, must do it ourselves
        # (forward pads the input by hand, so padding stays 0 here)
        self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Return ``x`` downsampled by a factor of two."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # One zero column on the right and one zero row at the bottom.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm, SiLU, 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
        conv_shortcut: when the channel counts differ, use a 3x3 conv on the
            skip path instead of a 1x1 conv.
        dropout: dropout probability applied before the second conv.
        temb_channels: width of the timestep embedding; the projection layer
            is only created when this is greater than 0.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # In-place activation: only applied to tensors this block owns.
        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Normalize(in_channels)
        self.conv1 = ops.Conv2d(in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        if temb_channels > 0:
            self.temb_proj = ops.Linear(temb_channels,
                                        out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = ops.Conv2d(out_channels,
                                out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = ops.Conv2d(in_channels,
                                                out_channels,
                                                kernel_size=3,
                                                stride=1,
                                                padding=1)
            else:
                self.nin_shortcut = ops.Conv2d(in_channels,
                                               out_channels,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0)

    def forward(self, x, temb):
        """Return ``shortcut(x) + residual(x, temb)``.

        ``temb`` may be ``None``, in which case no timestep projection is
        added. ``temb`` is never modified.
        """
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            # Use an out-of-place SiLU here: self.swish is in-place and would
            # overwrite the caller's temb, which is reused by later blocks.
            h = h + self.temb_proj(torch.nn.functional.silu(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
|
|
|
|
def slice_attention(q, k, v):
    """Compute softmax attention in query slices to bound peak memory.

    The result is allocated with the shape of ``k`` on ``q.device``. The
    initial number of slices is estimated from the free memory reported by
    ``model_management``. On an out-of-memory error the cache is emptied and
    the slice count doubled; once it exceeds 128 the error is re-raised.
    """
    out = torch.zeros_like(k, device=q.device)
    softmax_scale = int(q.shape[-1]) ** (-0.5)

    free_bytes = model_management.get_free_memory(q.device)

    elem_bytes = q.element_size()
    scores_bytes = q.shape[0] * q.shape[1] * k.shape[2] * elem_bytes
    overhead = 3 if elem_bytes == 2 else 2.5
    needed_bytes = scores_bytes * overhead

    steps = 1
    if needed_bytes > free_bytes:
        # Round the shortfall up to the next power of two.
        steps = 2 ** (math.ceil(math.log(needed_bytes / free_bytes, 2)))

    n_queries = q.shape[1]
    while True:
        try:
            # Slice only when the step count divides the query length evenly;
            # otherwise process all queries in a single pass.
            if n_queries % steps == 0:
                chunk = n_queries // steps
            else:
                chunk = n_queries
            for start in range(0, n_queries, chunk):
                stop = start + chunk
                scores = torch.bmm(q[:, start:stop], k) * softmax_scale

                weights = torch.nn.functional.softmax(scores, dim=2).permute(0, 2, 1)
                del scores

                out[:, :, start:stop] = torch.bmm(v, weights)
                del weights
            break
        except model_management.OOM_EXCEPTION as e:
            model_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise e
            print("out of memory error, increasing steps and trying again", steps)

    return out
|
|
|
|
def normal_attention(q, k, v):
    """Run ``slice_attention`` on 4-D feature maps.

    ``q``, ``k`` and ``v`` are reshaped using the shape of ``q``
    (b, c, h, w); the result is reshaped back to (b, c, h, w).
    """
    b, c, h, w = q.shape
    tokens = h * w

    queries = q.reshape(b, c, tokens).permute(0, 2, 1)  # b,hw,c
    keys = k.reshape(b, c, tokens)  # b,c,hw
    values = v.reshape(b, c, tokens)

    attended = slice_attention(queries, keys, values)
    return attended.reshape(b, c, h, w)
|
|
|
|
def xformers_attention(q, k, v):
    """Attention over 4-D feature maps via xformers.

    Falls back to ``slice_attention`` when xformers raises
    ``NotImplementedError``. The result has the shape of ``q`` (B, C, H, W).
    """
    B, C, H, W = q.shape

    def flatten(t):
        # (B, C, H, W) -> (B, H*W, C), contiguous
        return t.view(B, C, -1).transpose(1, 2).contiguous()

    q, k, v = flatten(q), flatten(k), flatten(v)

    try:
        attended = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = attended.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError:
        sliced = slice_attention(
            q.view(B, -1, C),
            k.view(B, -1, C).transpose(1, 2),
            v.view(B, -1, C).transpose(1, 2),
        )
        out = sliced.reshape(B, C, H, W)
    return out
|
|
|
|
def pytorch_attention(q, k, v):
    """Attention over 4-D feature maps via ``scaled_dot_product_attention``.

    Falls back to ``slice_attention`` when the fused kernel raises the
    out-of-memory exception. The result has the shape of ``q`` (B, C, H, W).
    """
    B, C, H, W = q.shape

    def flatten(t):
        # (B, C, H, W) -> (B, 1, H*W, C), contiguous
        return t.view(B, 1, C, -1).transpose(2, 3).contiguous()

    q, k, v = flatten(q), flatten(k), flatten(v)

    try:
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        out = attended.transpose(2, 3).reshape(B, C, H, W)
    except model_management.OOM_EXCEPTION as e:
        print("scaled_dot_product_attention OOMed: switched to slice attention")
        sliced = slice_attention(
            q.view(B, -1, C),
            k.view(B, -1, C).transpose(1, 2),
            v.view(B, -1, C).transpose(1, 2),
        )
        out = sliced.reshape(B, C, H, W)
    return out
|
|
|
|
|
|
class AttnBlock(nn.Module):
    """Residual self-attention block over a 4-D feature map.

    The input is group-normalised, projected to q/k/v with 1x1 convolutions,
    passed through the attention backend chosen at construction time,
    projected again with a 1x1 convolution and added back to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count.
            return ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

        # Pick the attention backend: xformers, then pytorch, then split.
        if model_management.xformers_enabled_vae():
            notice, attention_fn = "Using xformers attention in VAE", xformers_attention
        elif model_management.pytorch_attention_enabled():
            notice, attention_fn = "Using pytorch attention in VAE", pytorch_attention
        else:
            notice, attention_fn = "Using split attention in VAE", normal_attention
        print(notice)
        self.optimized_attention = attention_fn

    def forward(self, x):
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        attended = self.optimized_attention(queries, keys, values)

        return x + self.proj_out(attended)
|
|
|
|
|
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """Build an ``AttnBlock`` for ``in_channels`` channels.

    ``attn_type`` and ``attn_kwargs`` are accepted but not used; an
    ``AttnBlock`` is always returned.
    """
    attention = AttnBlock(in_channels)
    return attention
|
|
|
|
|
|
class Model(nn.Module):
    """Encoder/decoder network with skip connections and optional timestep embedding.

    The down path stores every intermediate activation; the up path pops
    them in reverse and concatenates each with the running feature map along
    the channel axis before the matching residual block.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two linear layers, ch -> temb_ch -> temb_ch
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                ops.Linear(self.ch, self.temb_ch),
                ops.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        last_level = self.num_resolutions - 1
        self.down = nn.ModuleList()
        for level in range(self.num_resolutions):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_in = ch * in_ch_mult[level]
            block_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks):
                stage.block.append(ResnetBlock(in_channels=block_in,
                                               out_channels=block_out,
                                               temb_channels=self.temb_ch,
                                               dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    stage.attn.append(make_attn(block_in, attn_type=attn_type))
            if level != last_level:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for level in reversed(range(self.num_resolutions)):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_out = ch * ch_mult[level]
            for idx in range(self.num_res_blocks + 1):
                # The last block of a level consumes the skip tensor that
                # entered the matching down level, hence in_ch_mult.
                if idx == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[level]
                else:
                    skip_in = ch * ch_mult[level]
                stage.block.append(ResnetBlock(in_channels=block_in + skip_in,
                                               out_channels=block_out,
                                               temb_channels=self.temb_ch,
                                               dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    stage.attn.append(make_attn(block_in, attn_type=attn_type))
            if level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        """Run the network on ``x``.

        ``context``, when given, is concatenated to ``x`` along dim 1.
        ``t`` must be provided when the model was built with
        ``use_timestep=True``.
        """
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)

        temb = None
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)

        # downsampling
        skips = [self.conv_in(x)]
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](skips[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
                skips.append(h)
            if level != last_level:
                skips.append(stage.downsample(skips[-1]))

        # middle
        h = skips[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for idx in range(self.num_res_blocks + 1):
                merged = torch.cat([h, skips.pop()], dim=1)
                h = stage.block[idx](merged, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != 0:
                h = stage.upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight of the final output convolution."""
        return self.conv_out.weight
|
|
|
|
|
|
class Encoder(nn.Module):
    """Downsampling encoder with no timestep conditioning.

    The final convolution emits ``2 * z_channels`` channels when ``double_z``
    is true and ``z_channels`` otherwise. ``out_ch`` and any extra keyword
    arguments are accepted but not used.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # residual blocks are built without a timestep projection
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        last_level = self.num_resolutions - 1
        self.down = nn.ModuleList()
        for level in range(self.num_resolutions):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_in = ch * in_ch_mult[level]
            block_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks):
                stage.block.append(ResnetBlock(in_channels=block_in,
                                               out_channels=block_out,
                                               temb_channels=self.temb_ch,
                                               dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    stage.attn.append(make_attn(block_in, attn_type=attn_type))
            if level != last_level:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        latent_channels = 2 * z_channels if double_z else z_channels
        self.conv_out = ops.Conv2d(block_in, latent_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode ``x`` through the down path, middle blocks and output conv."""
        # timestep embedding
        temb = None

        # downsampling
        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != last_level:
                h = stage.downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
|
|
|
|
|
class Decoder(nn.Module):
    """Upsampling decoder with no timestep conditioning.

    Args:
        ch, ch_mult: base width and per-level multipliers.
        out_ch: channels produced by the final convolution.
        num_res_blocks: each level gets ``num_res_blocks + 1`` residual blocks.
        attn_resolutions: resolutions at which attention blocks are added.
        resolution: full output resolution; the lowest level works at
            ``resolution // 2**(len(ch_mult) - 1)``.
        z_channels: channels of the latent input.
        give_pre_end: return the features before the final norm/activation/conv.
        tanh_out: apply ``tanh`` to the output.
        use_linear_attn: accepted for signature compatibility; it has no
            effect because attention blocks are always built with ``attn_op``.
        conv_out_op, resnet_op, attn_op: constructors for the output conv,
            residual blocks and attention blocks.
        **ignorekwargs: ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 conv_out_op=ops.Conv2d,
                 resnet_op=ResnetBlock,
                 attn_op=AttnBlock,
                 **ignorekwargs):
        super().__init__()
        # The former `attn_type = "linear"` and `in_ch_mult` locals were
        # never read in this class and have been removed.
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = ops.Conv2d(z_channels,
                                  block_in,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)
        self.mid.attn_1 = attn_op(block_in)
        self.mid.block_2 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(resnet_op(in_channels=block_in,
                                       out_channels=block_out,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(attn_op(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_out_op(block_in,
                                    out_ch,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, z, **kwargs):
        """Decode latent ``z``.

        Extra keyword arguments are forwarded to the residual blocks, the
        attention blocks and the output convolution (not to the upsamplers).
        The shape of ``z`` is recorded in ``self.last_z_shape``.
        """
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|