From 04abe57c76e0e2e72002355f7148de19ef4aed1c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 22 Jan 2025 08:50:57 -0700 Subject: [PATCH] Added weighing to DFE --- extensions_built_in/sd_trainer/SDTrainer.py | 3 ++- toolkit/config_modules.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index a8e0cde7..5e405333 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -385,7 +385,8 @@ class SDTrainer(BaseSDTrainProcess): # do diffusion feature extraction on prediction pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) - additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") + additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \ + self.train_config.diffusion_feature_extractor_weight if target is None: target = noise diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index b4417afc..15dca441 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -403,6 +403,7 @@ class TrainConfig: # diffusion feature extractor self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None) + self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1) # optimal noise pairing self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)