This commit is contained in:
layerdiffusion
2024-08-01 13:07:42 -07:00
parent 37492057e5
commit fec51bf2dd

View File

@@ -273,7 +273,6 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
resnet_op=ResnetBlock,
**kwargs):
super().__init__()
self.ch = ch
@@ -298,15 +297,15 @@ class Decoder(nn.Module):
padding=1)
self.mid = nn.Module()
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
@@ -314,10 +313,10 @@ class Decoder(nn.Module):
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
@@ -340,15 +339,12 @@ class Decoder(nn.Module):
def forward(self, z, **kwargs):
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
@@ -357,7 +353,6 @@ class Decoder(nn.Module):
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h