diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 1705bdc2..e3b42ff3 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -25,13 +25,11 @@ class DiffusionFeatureExtractor(nn.Module): super().__init__() num_blocks = 6 self.conv_in = nn.Conv2d(in_channels, 512, 1) - self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1) self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)]) self.conv_out = nn.Conv2d(512, 512, 1) def forward(self, x): x = self.conv_in(x) - x = self.conv_pool(x) for block in self.blocks: x = block(x) x = self.conv_out(x)