Changes to handle a different DFE (DiffusionFeatureExtractor) architecture

This commit is contained in:
Jaret Burkett
2025-08-27 11:05:16 -06:00
parent fd13bd73a6
commit 1f541bc5d8
2 changed files with 10 additions and 5 deletions

View File

@@ -305,8 +305,11 @@ class SDTrainer(BaseSDTrainProcess):
# enable gradient checkpointing on the vae
if vae is not None and self.train_config.gradient_checkpointing:
vae.enable_gradient_checkpointing()
vae.train()
try:
vae.enable_gradient_checkpointing()
vae.train()
except:
pass
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):

View File

@@ -140,13 +140,13 @@ class DFEBlock(nn.Module):
class DiffusionFeatureExtractor(nn.Module):
# NOTE(review): this hunk shows removed and added diff lines together without
# +/- markers. The hunk header (-140,13 +140,13) implies equal removals and
# additions, so the first `def` below is presumably the old signature and the
# second the new one — confirm against the actual file.
def __init__(self, in_channels=16):
# Constructor with a configurable output width: out_channels defaults to 512,
# the same value the other conv_out line in this hunk hard-codes.
def __init__(self, in_channels=16, out_channels=512):
super().__init__()
self.version = 1
num_blocks = 6
# Input projection from in_channels to 512 channels; the third positional
# argument is 1 (kernel size, assuming nn is torch.nn — TODO confirm).
self.conv_in = nn.Conv2d(in_channels, 512, 1)
# num_blocks DFEBlock modules, each constructed with 512.
self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
# Presumably the pre-change output layer with a fixed 512 outputs (see the
# note at the top of this hunk).
self.conv_out = nn.Conv2d(512, 512, 1)
# Output projection from 512 channels to the caller-selected out_channels.
self.conv_out = nn.Conv2d(512, out_channels, 1)
def forward(self, x):
x = self.conv_in(x)
@@ -611,7 +611,9 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
state_dict = state_dict['model_state_dict']
if 'conv_in.weight' in state_dict:
dfe = DiffusionFeatureExtractor()
# determine the number of output channels from the state dict's conv_out weight
out_channels = state_dict['conv_out.weight'].shape[0]
dfe = DiffusionFeatureExtractor(out_channels=out_channels)
else:
dfe = DiffusionFeatureExtractor2()