Add new version of DFE. Kitchen sink

This commit is contained in:
Jaret Burkett
2025-01-31 11:42:27 -07:00
parent 34a1c6947a
commit 15a57bc89f
4 changed files with 203 additions and 2 deletions

View File

@@ -387,7 +387,7 @@ class SDTrainer(BaseSDTrainProcess):
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
else:
elif self.dfe.version == 2:
# version 2
# do diffusion feature extraction on target
with torch.no_grad():
@@ -402,6 +402,17 @@ class SDTrainer(BaseSDTrainProcess):
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
elif self.dfe.version == 3:
dfe_loss = self.dfe(
noise_pred=noise_pred,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
scheduler=self.sd.noise_scheduler
)
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
else:
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
if target is None: