Added additional config options for custom plugins I needed

This commit is contained in:
Jaret Burkett
2024-01-15 08:31:09 -07:00
parent e190fbaeb8
commit 5276975fb0
7 changed files with 37 additions and 31 deletions

View File

@@ -802,7 +802,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding:
if self.embedding is not None:
grad_on_text_encoder = True
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
@@ -1095,13 +1095,14 @@ class SDTrainer(BaseSDTrainProcess):
else:
with self.timer('predict_unet'):
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
detach_unconditional=False,
**pred_kwargs
)
self.after_unet_predict()