mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added v2 of dfp
This commit is contained in:
@@ -378,15 +378,32 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
target = noise
|
||||
|
||||
if self.dfe is not None:
|
||||
# do diffusion feature extraction on target
|
||||
with torch.no_grad():
|
||||
rectified_flow_target = noise.float() - batch.latents.float()
|
||||
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
|
||||
|
||||
# do diffusion feature extraction on prediction
|
||||
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
|
||||
self.train_config.diffusion_feature_extractor_weight
|
||||
if self.dfe.version == 1:
|
||||
# do diffusion feature extraction on target
|
||||
with torch.no_grad():
|
||||
rectified_flow_target = noise.float() - batch.latents.float()
|
||||
target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
|
||||
|
||||
# do diffusion feature extraction on prediction
|
||||
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
|
||||
self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
# version 2
|
||||
# do diffusion feature extraction on target
|
||||
with torch.no_grad():
|
||||
rectified_flow_target = noise.float() - batch.latents.float()
|
||||
target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
|
||||
|
||||
# do diffusion feature extraction on prediction
|
||||
pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
||||
|
||||
dfe_loss = 0.0
|
||||
for i in range(len(target_feature_list)):
|
||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
|
||||
|
||||
if target is None:
|
||||
target = noise
|
||||
|
||||
Reference in New Issue
Block a user