From fec51bf2dd0567ddd7c0084b9d4cd787bec24765 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Thu, 1 Aug 2024 13:07:42 -0700 Subject: [PATCH] ling --- backend/nn/autoencoder_kl.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/backend/nn/autoencoder_kl.py b/backend/nn/autoencoder_kl.py index 0255d576..d2f9db6d 100644 --- a/backend/nn/autoencoder_kl.py +++ b/backend/nn/autoencoder_kl.py @@ -273,7 +273,6 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - resnet_op=ResnetBlock, **kwargs): super().__init__() self.ch = ch @@ -298,15 +297,15 @@ class Decoder(nn.Module): padding=1) self.mid = nn.Module() - self.mid.block_1 = resnet_op(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = resnet_op(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): @@ -314,10 +313,10 @@ class Decoder(nn.Module): attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): - block.append(resnet_op(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(AttnBlock(block_in)) @@ -340,15 +339,12 @@ class Decoder(nn.Module): def forward(self, z, **kwargs): temb = None - # z to block_in h = self.conv_in(z) - # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) - # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb, **kwargs) @@ -357,7 +353,6 @@ class Decoder(nn.Module): if i_level != 0: h = self.up[i_level].upsample(h) - # end if self.give_pre_end: return h