Tied in and tested TI script

Jaret Burkett
2023-08-23 13:26:28 -06:00
parent 2e6c55c720
commit d298240cec
6 changed files with 89 additions and 58 deletions


@@ -25,7 +25,9 @@ class Embedding:
    ):
        self.name = embed_config.trigger
        self.sd = sd
        self.trigger = embed_config.trigger
        self.embed_config = embed_config
        self.step = 0
        # setup our embedding
        # Add the placeholder token in tokenizer
        placeholder_tokens = [self.embed_config.trigger]
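For reference, the tokenizer setup this constructor relies on usually follows the standard textual inversion pattern sketched below. This is a minimal sketch assuming Hugging Face transformers; the model name, trigger word, and vector count are illustrative and not taken from this repo.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

trigger = "test123"   # illustrative trigger word
num_vectors = 3       # illustrative number of embedding vectors

# "test123" plus the expanded tokens "test123_1", "test123_2", ...
placeholder_tokens = [trigger] + [f"{trigger}_{i}" for i in range(1, num_vectors)]

num_added = tokenizer.add_tokens(placeholder_tokens)
if num_added != len(placeholder_tokens):
    raise ValueError("Some placeholder tokens already exist in the tokenizer")

# grow the text encoder's embedding table so the new ids have rows to train
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)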
@@ -64,10 +66,7 @@ class Embedding:
        for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
            token_embeds[token_id] = token_embeds[initializer_token_id].clone()
        # this doesn't seem to be used again
        self.token_embeds = token_embeds
        # replace "[name]" with this. This triggers it in the text encoder
        # replace "[name]" with this on training. This is automatically generated in the pipeline on inference
        self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))

    # returns the string to have in the prompt to trigger the embedding
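Continuing the sketch above, the expanded trigger string is simply the placeholder tokens joined with spaces, and that string is what actually appears in training prompts:

# illustrative: what self.embedding_tokens ends up holding for a 3-vector embedding
embedding_tokens = " ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))
print(embedding_tokens)                  # "test123 test123_1 test123_2"
print(f"a photo of {embedding_tokens}")  # the form a training prompt takes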
@@ -86,7 +85,7 @@ class Embedding:
        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
        # stack the token embeddings along a new first axis
        new_vector = torch.stack(
            [token_embeds[token_id].unsqueeze(0) for token_id in self.placeholder_token_ids],
            [token_embeds[token_id] for token_id in self.placeholder_token_ids],
            dim=0
        )
        return new_vector
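The change above drops the per-token unsqueeze, so the stacked result is a plain 2-D matrix. A quick shape check with stand-in sizes (768 is the SD 1.x CLIP hidden size; the ids are made up):

import torch

token_embeds = torch.randn(49411, 768)          # stand-in for the embedding table
placeholder_token_ids = [49408, 49409, 49410]   # stand-in ids for the added tokens

new_vector = torch.stack([token_embeds[t] for t in placeholder_token_ids], dim=0)
assert new_vector.shape == (len(placeholder_token_ids), 768)  # was (3, 1, 768) with unsqueeze(0)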
@@ -100,6 +99,39 @@ class Embedding:
            token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
        x = 1

    # diffusers automatically expands the token, meaning test123 becomes "test123 test123_1 test123_2", etc.
    # however, on training we don't use that pipeline, so we have to do it ourselves
    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
        output_prompt = prompt
        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]

        replace_with = self.embedding_tokens if expand_token else self.trigger
        if to_replace_list is None:
            to_replace_list = default_replacements
        else:
            to_replace_list += default_replacements

        # remove duplicates
        to_replace_list = list(set(to_replace_list))

        # replace them all
        for to_replace in to_replace_list:
            # replace it
            output_prompt = output_prompt.replace(to_replace, replace_with)
        # see how many times replace_with appears in the updated prompt
        num_instances = output_prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt
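A usage sketch of the behavior this method is meant to have; embedding stands for an already-constructed instance whose trigger is "test123" with three vectors (values illustrative):

prompt = "a photo of [trigger] riding a bike"

# expand_token=True swaps in the full multi-vector token string for training
embedding.inject_embedding_to_prompt(prompt, expand_token=True)
# -> "a photo of test123 test123_1 test123_2 riding a bike"

# with no trigger present and add_if_not_present=True (the default), it is prepended
embedding.inject_embedding_to_prompt("a photo of a dog")
# -> "test123 a photo of a dog"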
    def save(self, filename):
        # todo check to see how to get the vector out of the embedding
@@ -107,7 +139,7 @@ class Embedding:
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": 0,
"step": self.step,
# todo get these
"sd_checkpoint": None,
"sd_checkpoint_name": None,
@@ -182,4 +214,7 @@ class Embedding:
            raise Exception(
                f"Couldn't identify {filename} as either a textual inversion embedding or a diffusers concept.")
        if 'step' in data:
            self.step = int(data['step'])
        self.vec = emb.detach().to(device, dtype=torch.float32)
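For context, the two on-disk layouts the loader distinguishes look roughly like this. The detection logic below is an assumption sketched from the error message above, not the repo's exact code:

import torch

data = torch.load("embedding.pt", map_location="cpu")

if "string_to_param" in data:
    # A1111-style textual inversion embedding
    emb = next(iter(data["string_to_param"].values()))
elif isinstance(data, dict) and len(data) == 1:
    # diffusers concept: {"<token>": tensor}
    emb = next(iter(data.values()))
else:
    raise Exception("Couldn't identify embedding.pt as either a textual inversion embedding or a diffusers concept.")

step = int(data["step"]) if "step" in data else 0
vec = emb.detach().to("cpu", dtype=torch.float32)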