Tied in ant tested TI script

This commit is contained in:
Jaret Burkett
2023-08-23 13:26:28 -06:00
parent 2e6c55c720
commit d298240cec
6 changed files with 89 additions and 58 deletions

View File

@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
prompt = sample_config.prompts[i]
# add embedding if there is one
# note: diffusers will automatically expand the trigger to the number of added tokens
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
if self.embedding is not None:
# replace our name with the embedding
if self.embed_config.trigger in prompt:
# if the trigger is a part of the prompt, replace it with the token ids
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
if self.name in prompt:
# if the name is in the prompt, replace it with the trigger
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
if "[name]" in prompt:
# in [name] in prompt, replace it with the trigger
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
if self.embedding.get_embedding_string() not in prompt:
# add it to the beginning of the prompt
prompt = self.embedding.get_embedding_string() + " " + prompt
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
)
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
self.network.multiplier = prev_multiplier
elif self.embedding is not None:
# set current step
self.embedding.step = self.step_num
# change filename to pt if that is set
if self.embed_config.save_format == "pt":
# replace extension
file_path = os.path.splitext(file_path)[0] + ".pt"
self.embedding.save(file_path)
else:
self.sd.save(
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
def before_dataset_load(self):
pass
def hook_train_loop(self, batch=None):
def hook_train_loop(self, batch):
# return loss
return 0.0
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
# resume state from embedding
self.step_num = self.embedding.step
# set trainable params
params = self.embedding.get_trainable_params()