mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Update dfe model arch
This commit is contained in:
@@ -25,13 +25,11 @@ class DiffusionFeatureExtractor(nn.Module):
|
||||
super().__init__()
|
||||
num_blocks = 6
|
||||
self.conv_in = nn.Conv2d(in_channels, 512, 1)
|
||||
self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1)
|
||||
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
|
||||
self.conv_out = nn.Conv2d(512, 512, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.conv_pool(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
Reference in New Issue
Block a user