Fixed issue with loadin models after resume function added. Added additional flush if not training text encoder to clear out vram before grad accum

This commit is contained in:
Jaret Burkett
2023-08-28 17:56:30 -06:00
parent b79ced3e10
commit a008d9e63b
3 changed files with 18 additions and 6 deletions

View File

@@ -62,6 +62,11 @@ class SDTrainer(BaseSDTrainProcess):
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
self.optimizer.zero_grad()
flush()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),