Added weighing to DFE

This commit is contained in:
Jaret Burkett
2025-01-22 08:50:57 -07:00
parent 89dd041b97
commit 04abe57c76
2 changed files with 3 additions and 1 deletions

View File

@@ -385,7 +385,8 @@ class SDTrainer(BaseSDTrainProcess):
# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
if target is None:
target = noise

View File

@@ -403,6 +403,7 @@ class TrainConfig:
# diffusion feature extractor
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1)
# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)