Update dfe model arch

This commit is contained in:
Jaret Burkett
2025-01-22 10:37:23 -07:00
parent 04abe57c76
commit e1549ad54d

View File

@@ -25,13 +25,11 @@ class DiffusionFeatureExtractor(nn.Module):
super().__init__()
num_blocks = 6
self.conv_in = nn.Conv2d(in_channels, 512, 1)
self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1)
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
self.conv_out = nn.Conv2d(512, 512, 1)
def forward(self, x):
x = self.conv_in(x)
x = self.conv_pool(x)
for block in self.blocks:
x = block(x)
x = self.conv_out(x)