diff --git a/backend/nn/autoencoder_kl.py b/backend/nn/autoencoder_kl.py index 8019c065..5e830fc8 100644 --- a/backend/nn/autoencoder_kl.py +++ b/backend/nn/autoencoder_kl.py @@ -273,7 +273,6 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - conv_out_op=nn.Conv2d, resnet_op=ResnetBlock, **kwargs): super().__init__() @@ -332,11 +331,11 @@ class Decoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = conv_out_op(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) def forward(self, z, **kwargs): temb = None