Changes to handle a different DFE arch

This commit is contained in:
Jaret Burkett
2025-08-27 11:05:16 -06:00
parent fd13bd73a6
commit 1f541bc5d8
2 changed files with 10 additions and 5 deletions

View File

@@ -140,13 +140,13 @@ class DFEBlock(nn.Module):
class DiffusionFeatureExtractor(nn.Module):
def __init__(self, in_channels=16):
def __init__(self, in_channels=16, out_channels=512):
super().__init__()
self.version = 1
num_blocks = 6
self.conv_in = nn.Conv2d(in_channels, 512, 1)
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
self.conv_out = nn.Conv2d(512, 512, 1)
self.conv_out = nn.Conv2d(512, out_channels, 1)
def forward(self, x):
x = self.conv_in(x)
@@ -611,7 +611,9 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
state_dict = state_dict['model_state_dict']
if 'conv_in.weight' in state_dict:
dfe = DiffusionFeatureExtractor()
# determine num out channels
out_channels = state_dict['conv_out.weight'].shape[0]
dfe = DiffusionFeatureExtractor(out_channels=out_channels)
else:
dfe = DiffusionFeatureExtractor2()