mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue with loadin models after resume function added. Added additional flush if not training text encoder to clear out vram before grad accum
This commit is contained in:
@@ -62,6 +62,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
if not grad_on_text_encoder:
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
self.optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
|
||||
@@ -86,13 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
if self.embed_config is None and self.network_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.model_config.name_or_path = latest_save_path
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.model_config.name_or_path = latest_save_path
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
|
||||
@@ -38,6 +38,12 @@ class PromptEmbeds:
|
||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
self.text_embeds = self.text_embeds.detach()
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.detach()
|
||||
return self
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user