From 29122b1a5442e1fd789f0771d0f1bd25083f65a7 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 21 Jan 2025 14:21:34 -0700
Subject: [PATCH] Added code to handle diffusion feature extraction loss

---
 extensions_built_in/sd_trainer/SDTrainer.py  | 23 +++++++-
 toolkit/config_modules.py                    |  3 +
 .../models/diffusion_feature_extraction.py   | 55 +++++++++++++++++++
 3 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 toolkit/models/diffusion_feature_extraction.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 2a4d051d..a8e0cde7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -32,6 +32,7 @@ from torchvision import transforms
 from diffusers import EMAModel
 import math
 from toolkit.train_tools import precondition_model_outputs_flow_match
+from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 
 
 def flush():
@@ -78,6 +79,8 @@ class SDTrainer(BaseSDTrainProcess):
 
         self.cached_blank_embeds: Optional[PromptEmbeds] = None
         self.cached_trigger_embeds: Optional[PromptEmbeds] = None
+
+        self.dfe: Optional[DiffusionFeatureExtractor] = None
 
 
     def before_model_load(self):
@@ -178,6 +181,11 @@ class SDTrainer(BaseSDTrainProcess):
             # move back to cpu
             self.sd.text_encoder_to('cpu')
             flush()
+
+        if self.train_config.diffusion_feature_extractor_path is not None:
+            self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path)
+            self.dfe.to(self.device_torch)
+            self.dfe.eval()
 
 
     def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
@@ -285,6 +293,7 @@ class SDTrainer(BaseSDTrainProcess):
     ):
         loss_target = self.train_config.loss_target
         is_reg = any(batch.get_is_reg_list())
+        additional_loss = 0.0
 
         prior_mask_multiplier = None
         target_mask_multiplier = None
@@ -367,7 +376,17 @@ class SDTrainer(BaseSDTrainProcess):
                 target = (noise - batch.latents).detach()
             else:
                 target = noise
-
+        if self.dfe is not None:
+            # do diffusion feature extraction on target
+            with torch.no_grad():
+                rectified_flow_target = noise.float() - batch.latents.float()
+                target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
+
+            # do diffusion feature extraction on prediction
+            pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
+            additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")
+
         if target is None:
             target = noise
 
@@ -487,7 +506,7 @@ class SDTrainer(BaseSDTrainProcess):
 
             loss = loss + norm_std_loss
 
-        return loss
+        return loss + additional_loss
 
     def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
         return batch
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 65a3691d..592f6e09 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -400,6 +400,9 @@ class TrainConfig:
         self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
         # bypass the guidance embedding for training. For open flux with guidance embedding
         self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
+
+        # diffusion feature extractor
+        self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
 
 
 class ModelConfig:
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
new file mode 100644
index 00000000..1705bdc2
--- /dev/null
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -0,0 +1,55 @@
+import torch
+import os
+from torch import nn
+from safetensors.torch import load_file
+
+
+class DFEBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
+        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
+        self.act = nn.GELU()
+
+    def forward(self, x):
+        x_in = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.act(x)
+        x = x + x_in
+        return x
+
+
+class DiffusionFeatureExtractor(nn.Module):
+    def __init__(self, in_channels=32):
+        super().__init__()
+        num_blocks = 6
+        self.conv_in = nn.Conv2d(in_channels, 512, 1)
+        self.conv_pool = nn.Conv2d(512, 512, 3, stride=2, padding=1)
+        self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
+        self.conv_out = nn.Conv2d(512, 512, 1)
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        x = self.conv_pool(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.conv_out(x)
+        return x
+
+
+def load_dfe(model_path) -> DiffusionFeatureExtractor:
+    dfe = DiffusionFeatureExtractor()
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"Model file not found: {model_path}")
+    # if it ends with safetensors
+    if model_path.endswith('.safetensors'):
+        state_dict = load_file(model_path)
+    else:
+        state_dict = torch.load(model_path, weights_only=True)
+    if 'model_state_dict' in state_dict:
+        state_dict = state_dict['model_state_dict']
+
+    dfe.load_state_dict(state_dict)
+    dfe.eval()
+    return dfe
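
Usage note (illustrative, not part of the patch): the sketch below mirrors how SDTrainer wires the new pieces together once the new diffusion_feature_extractor_path option on the train config points at a trained extractor checkpoint. It assumes 16-channel latents so that concatenating the rectified-flow target with the noise matches DiffusionFeatureExtractor's default in_channels=32; the tensor shapes are placeholders, and the extractor is randomly initialized here only so the snippet runs without a real checkpoint (the trainer itself calls load_dfe instead).

import torch
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe

# Randomly initialized stand-in so the sketch runs without a checkpoint;
# real training would use: dfe = load_dfe(train_config.diffusion_feature_extractor_path)
dfe = DiffusionFeatureExtractor(in_channels=32).eval()

# Placeholder tensors for one training step: (batch, channels, height, width) latents.
latents = torch.randn(1, 16, 64, 64)
noise = torch.randn_like(latents)
noise_pred = torch.randn_like(latents)  # would come from the diffusion model

with torch.no_grad():
    # feature target: rectified-flow target (noise - latents) concatenated with the noise
    target_features = dfe(torch.cat([(noise - latents).float(), noise.float()], dim=1))

# features of the prediction, compared against the target features with MSE
pred_features = dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")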