mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Tied in ant tested TI script
This commit is contained in:
@@ -106,20 +106,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
prompt = sample_config.prompts[i]
|
||||
|
||||
# add embedding if there is one
|
||||
# note: diffusers will automatically expand the trigger to the number of added tokens
|
||||
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
|
||||
if self.embedding is not None:
|
||||
# replace our name with the embedding
|
||||
if self.embed_config.trigger in prompt:
|
||||
# if the trigger is a part of the prompt, replace it with the token ids
|
||||
prompt = prompt.replace(self.embed_config.trigger, self.embedding.get_embedding_string())
|
||||
if self.name in prompt:
|
||||
# if the name is in the prompt, replace it with the trigger
|
||||
prompt = prompt.replace(self.name, self.embedding.get_embedding_string())
|
||||
if "[name]" in prompt:
|
||||
# in [name] in prompt, replace it with the trigger
|
||||
prompt = prompt.replace("[name]", self.embedding.get_embedding_string())
|
||||
if self.embedding.get_embedding_string() not in prompt:
|
||||
# add it to the beginning of the prompt
|
||||
prompt = self.embedding.get_embedding_string() + " " + prompt
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
)
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
@@ -208,6 +200,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
elif self.embedding is not None:
|
||||
# set current step
|
||||
self.embedding.step = self.step_num
|
||||
# change filename to pt if that is set
|
||||
if self.embed_config.save_format == "pt":
|
||||
# replace extension
|
||||
file_path = os.path.splitext(file_path)[0] + ".pt"
|
||||
self.embedding.save(file_path)
|
||||
else:
|
||||
self.sd.save(
|
||||
@@ -234,7 +232,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def before_dataset_load(self):
|
||||
pass
|
||||
|
||||
def hook_train_loop(self, batch=None):
|
||||
def hook_train_loop(self, batch):
|
||||
# return loss
|
||||
return 0.0
|
||||
|
||||
@@ -365,6 +363,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
|
||||
# resume state from embedding
|
||||
self.step_num = self.embedding.step
|
||||
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user