Fixed issue with loading models after the resume function was added. Added an additional flush, when not training the text encoder, to clear out VRAM before grad accumulation.

This commit is contained in:
Jaret Burkett
2023-08-28 17:56:30 -06:00
parent b79ced3e10
commit a008d9e63b
3 changed files with 18 additions and 6 deletions

View File

@@ -62,6 +62,11 @@ class SDTrainer(BaseSDTrainProcess):
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
self.optimizer.zero_grad()
flush()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),

View File

@@ -86,13 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
if self.embed_config is None and self.network_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
self.sd = StableDiffusion(
device=self.device,

View File

@@ -38,6 +38,12 @@ class PromptEmbeds:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
return self
def detach(self):
    """Detach the prompt embedding tensors from the autograd graph in place.

    ``text_embeds`` is always detached; ``pooled_embeds`` is detached only
    when present (it is ``None`` for models without a pooled output).

    Returns:
        self, so the call can be chained like ``to``.
    """
    self.text_embeds = self.text_embeds.detach()
    pooled = self.pooled_embeds
    self.pooled_embeds = pooled if pooled is None else pooled.detach()
    return self
class EncodedPromptPair:
def __init__(