mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add new version of DFE. Kitchen sink
This commit is contained in:
@@ -387,7 +387,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
|
||||
additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
|
||||
self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
elif self.dfe.version == 2:
|
||||
# version 2
|
||||
# do diffusion feature extraction on target
|
||||
with torch.no_grad():
|
||||
@@ -402,6 +402,17 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")
|
||||
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
elif self.dfe.version == 3:
|
||||
dfe_loss = self.dfe(
|
||||
noise_pred=noise_pred,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
scheduler=self.sd.noise_scheduler
|
||||
)
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||
|
||||
|
||||
if target is None:
|
||||
|
||||
Reference in New Issue
Block a user