mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Tied in ant tested TI script
This commit is contained in:
@@ -25,7 +25,9 @@ class Embedding:
|
||||
):
|
||||
self.name = embed_config.trigger
|
||||
self.sd = sd
|
||||
self.trigger = embed_config.trigger
|
||||
self.embed_config = embed_config
|
||||
self.step = 0
|
||||
# setup our embedding
|
||||
# Add the placeholder token in tokenizer
|
||||
placeholder_tokens = [self.embed_config.trigger]
|
||||
@@ -64,10 +66,7 @@ class Embedding:
|
||||
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
|
||||
# this doesnt seem to be used again
|
||||
self.token_embeds = token_embeds
|
||||
|
||||
# replace "[name] with this. This triggers it in the text encoder
|
||||
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
|
||||
|
||||
# returns the string to have in the prompt to trigger the embedding
|
||||
@@ -86,7 +85,7 @@ class Embedding:
|
||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
||||
# stack the tokens along batch axis adding that axis
|
||||
new_vector = torch.stack(
|
||||
[token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
|
||||
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
|
||||
dim=0
|
||||
)
|
||||
return new_vector
|
||||
@@ -100,6 +99,39 @@ class Embedding:
|
||||
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
|
||||
x = 1
|
||||
|
||||
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
|
||||
# however, on training we don't use that pipeline, so we have to do it ourselves
|
||||
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||
output_prompt = prompt
|
||||
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
|
||||
|
||||
replace_with = self.embedding_tokens if expand_token else self.trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
to_replace_list += default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
output_prompt = replace_with + " " + output_prompt
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
def save(self, filename):
|
||||
# todo check to see how to get the vector out of the embedding
|
||||
|
||||
@@ -107,7 +139,7 @@ class Embedding:
|
||||
"string_to_token": {"*": 265},
|
||||
"string_to_param": {"*": self.vec},
|
||||
"name": self.name,
|
||||
"step": 0,
|
||||
"step": self.step,
|
||||
# todo get these
|
||||
"sd_checkpoint": None,
|
||||
"sd_checkpoint_name": None,
|
||||
@@ -182,4 +214,7 @@ class Embedding:
|
||||
raise Exception(
|
||||
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
if 'step' in data:
|
||||
self.step = int(data['step'])
|
||||
|
||||
self.vec = emb.detach().to(device, dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user