mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added additional config options for custom plugins I needed
This commit is contained in:
@@ -802,7 +802,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.embedding:
|
||||
if self.embedding is not None:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
||||
@@ -1095,13 +1095,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
else:
|
||||
with self.timer('predict_unet'):
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
detach_unconditional=False,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
Reference in New Issue
Block a user