mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Small fixed for DFE, polar guidance, and other things
This commit is contained in:
@@ -404,13 +404,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
|
||||
elif self.dfe.version == 3:
|
||||
dfe_loss = self.dfe(
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
scheduler=self.sd.noise_scheduler
|
||||
)
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||
|
||||
@@ -563,6 +564,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise=noise,
|
||||
sd=self.sd,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
train_config=self.train_config,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user