mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added weighing to DFE
This commit is contained in:
@@ -385,7 +385,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# do diffusion feature extraction on prediction
|
||||
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean")
|
||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
|
||||
self.train_config.diffusion_feature_extractor_weight
|
||||
|
||||
if target is None:
|
||||
target = noise
|
||||
|
||||
@@ -403,6 +403,7 @@ class TrainConfig:
|
||||
|
||||
# diffusion feature extractor
|
||||
self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
|
||||
self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1)
|
||||
|
||||
# optimal noise pairing
|
||||
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
|
||||
|
||||
Reference in New Issue
Block a user