diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 4edbf26b..f9a9639e 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -62,6 +62,11 @@ class SDTrainer(BaseSDTrainProcess): embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) embedding_list.append(embedding) conditional_embeds = concat_prompt_embeds(embedding_list) + if not grad_on_text_encoder: + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + self.optimizer.zero_grad() + flush() noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f56bd15a..a565dcac 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -86,13 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess): if embedding_raw is not None: self.embed_config = EmbeddingConfig(**embedding_raw) + if self.embed_config is None and self.network_config is None: + # get the latest checkpoint + # check to see if we have a latest save + latest_save_path = self.get_latest_save_path() - # check to see if we have a latest save - latest_save_path = self.get_latest_save_path() - - if latest_save_path is not None: - print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - self.model_config.name_or_path = latest_save_path + if latest_save_path is not None: + print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + self.model_config.name_or_path = latest_save_path self.sd = StableDiffusion( device=self.device, diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 3ba3e257..d0673f5b 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -38,6 +38,12 @@ class PromptEmbeds: self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) return self + def detach(self): + self.text_embeds = self.text_embeds.detach() + if self.pooled_embeds is not None: + self.pooled_embeds = self.pooled_embeds.detach() + return self + class EncodedPromptPair: def __init__(