Added v2 of dfp

This commit is contained in:
Jaret Burkett
2025-01-22 16:32:13 -07:00
parent e1549ad54d
commit bbfba0c188
3 changed files with 175 additions and 38 deletions

View File

@@ -378,15 +378,32 @@ class SDTrainer(BaseSDTrainProcess):
target = noise
if self.dfe is not None:
# do diffusion feature extraction on target
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
if self.dfe.version == 1:
# do diffusion feature extraction on target
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
# do diffusion feature extraction on prediction
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
self.train_config.diffusion_feature_extractor_weight
else:
# version 2
# do diffusion feature extraction on target
with torch.no_grad():
rectified_flow_target = noise.float() - batch.latents.float()
target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
# do diffusion feature extraction on prediction
pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
dfe_loss = 0.0
for i in range(len(target_feature_list)):
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
if target is None:
target = noise