Bug fixes and improvements to token injection

This commit is contained in:
Jaret Burkett
2023-09-08 06:10:59 -06:00
parent 92a086d5a5
commit ce4f9fe02a
5 changed files with 74 additions and 63 deletions

View File

@@ -63,7 +63,7 @@ class SDTrainer(BaseSDTrainProcess):
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
self.optimizer.zero_grad()
flush()
flush()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -71,6 +71,7 @@ class SDTrainer(BaseSDTrainProcess):
timestep=timesteps,
guidance_scale=1.0,
)
flush()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype)