From 1f541bc5d8e45d28618613a4b0ef20675ceb3df9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 27 Aug 2025 11:05:16 -0600 Subject: [PATCH] Changes to handle a different DFE arch --- extensions_built_in/sd_trainer/SDTrainer.py | 7 +++++-- toolkit/models/diffusion_feature_extraction.py | 8 +++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index e8b6608f..f1d4e829 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -305,8 +305,11 @@ class SDTrainer(BaseSDTrainProcess): # enable gradient checkpointing on the vae if vae is not None and self.train_config.gradient_checkpointing: - vae.enable_gradient_checkpointing() - vae.train() + try: + vae.enable_gradient_checkpointing() + vae.train() + except: + pass def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 0f7bff09..5edff00d 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -140,13 +140,13 @@ class DFEBlock(nn.Module): class DiffusionFeatureExtractor(nn.Module): - def __init__(self, in_channels=16): + def __init__(self, in_channels=16, out_channels=512): super().__init__() self.version = 1 num_blocks = 6 self.conv_in = nn.Conv2d(in_channels, 512, 1) self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)]) - self.conv_out = nn.Conv2d(512, 512, 1) + self.conv_out = nn.Conv2d(512, out_channels, 1) def forward(self, x): x = self.conv_in(x) @@ -611,7 +611,9 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: state_dict = state_dict['model_state_dict'] if 'conv_in.weight' in state_dict: - dfe = DiffusionFeatureExtractor() + # determine num out channels + out_channels = state_dict['conv_out.weight'].shape[0] + dfe = DiffusionFeatureExtractor(out_channels=out_channels) else: dfe = DiffusionFeatureExtractor2()