Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-27 09:44:02 +00:00)
Various bug fixes. Finalized targeted guidance algo
@@ -47,12 +47,15 @@ class Embedding:
         self.placeholder_token_ids = []
         self.embedding_tokens = []
 
-        print(f"Adding {placeholder_tokens} tokens to tokenizer")
+        print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
 
         for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
             num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
             if num_added_tokens != self.embed_config.tokens:
                 raise ValueError(
                     f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
-                    " `placeholder_token` that is not already in the tokenizer."
+                    f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
                 )
 
         # Convert the initializer_token, placeholder_token to ids
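
For context, a minimal sketch (not part of this commit) of the placeholder-token pattern the hunk above relies on, assuming a Hugging Face CLIPTokenizer. The trigger word, token count, and naming scheme below are illustrative stand-ins for embed_config.trigger and embed_config.tokens.

# Illustrative sketch of adding multi-vector placeholder tokens to a tokenizer.
# The trigger word and token count are hypothetical examples.
from transformers import CLIPTokenizer

trigger = "my_concept"   # hypothetical trigger word
num_tokens = 4           # hypothetical number of embedding vectors

# one placeholder token per vector: my_concept, my_concept_1, my_concept_2, ...
placeholder_tokens = [trigger] + [f"{trigger}_{i}" for i in range(1, num_tokens)]

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != num_tokens:
    # mirrors the ValueError in the hunk: some placeholder tokens already existed
    raise ValueError(
        f"Tokenizer already contains some of {placeholder_tokens}; only added {num_added_tokens}"
    )

# these ids are what the learned embedding vectors get written to later
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

After add_tokens, the text encoder's input embedding table is typically resized with text_encoder.resize_token_embeddings(len(tokenizer)) before any new vectors are written in.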
@@ -115,10 +118,10 @@ class Embedding:
 
     def _set_vec(self, new_vector, text_encoder_idx=0):
         # shape is (1, 768) for SD 1.5 for 1 token
-        token_embeds = self.text_encoder_list[0].get_input_embeddings().weight.data
+        token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
         for i in range(new_vector.shape[0]):
             # apply the weights to the placeholder tokens while preserving gradient
-            token_embeds[self.placeholder_token_ids[0][i]] = new_vector[i].clone()
+            token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
 
     # make setter and getter for vec
     @property
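
The fix above matters for multi-encoder models: SDXL carries two text encoders with different hidden sizes (CLIP-L is 768-dim, OpenCLIP-G is 1280-dim), so always writing into index 0 targets the wrong embedding table for the second encoder. A minimal standalone sketch of the corrected behavior, assuming text_encoder_list and placeholder_token_ids are parallel per-encoder lists as in the hunk:

import torch

def set_vec(text_encoder_list, placeholder_token_ids, new_vector, text_encoder_idx=0):
    # embedding table of the chosen encoder, shape (vocab_size, hidden_dim)
    token_embeds = text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
    for i in range(new_vector.shape[0]):
        # copy one learned vector per placeholder token id of that same encoder
        token_embeds[placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()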
@@ -249,30 +252,32 @@ class Embedding:
         else:
             return
 
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = getattr(param_dict,
-                                     '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-        else:
-            raise Exception(
-                f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        if 'step' in data:
-            self.step = int(data['step'])
-
+        if self.sd.is_xl:
+            self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
+            self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+            if 'step' in data:
+                self.step = int(data['step'])
+        else:
+            # textual inversion embeddings
+            if 'string_to_param' in data:
+                param_dict = data['string_to_param']
+                if hasattr(param_dict, '_parameters'):
+                    param_dict = getattr(param_dict,
+                                         '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+                emb = next(iter(param_dict.items()))[1]
+            # diffuser concepts
+            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+                emb = next(iter(data.values()))
+                if len(emb.shape) == 1:
+                    emb = emb.unsqueeze(0)
+            else:
+                raise Exception(
+                    f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+            if 'step' in data:
+                self.step = int(data['step'])
+
+            self.vec = emb.detach().to(device, dtype=torch.float32)
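
A standalone sketch (illustrative, not the commit's loader) of the format dispatch this hunk introduces: XL embedding files carry one vector per text encoder ('clip_l' / 'clip_g'), while other files are either an A1111-style textual inversion embedding ('string_to_param') or a single-entry diffusers concept dict. The function name and argument layout below are assumptions; data and tensors are taken to be already deserialized, as in the hunk.

import torch

def load_embedding(data, tensors, is_xl, filename, device="cpu"):
    # training step counter, if the file recorded one
    step = int(data['step']) if 'step' in data else 0

    if is_xl:
        # SDXL keeps one vector per text encoder
        vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
        vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
        return vec, vec2, step

    # A1111-style textual inversion embedding
    if 'string_to_param' in data:
        param_dict = data['string_to_param']
        if hasattr(param_dict, '_parameters'):
            param_dict = param_dict._parameters  # torch 1.12.x compatibility, as in the hunk
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.values()))
    # diffusers "concept" format: {token: tensor}
    elif isinstance(data, dict) and isinstance(next(iter(data.values())), torch.Tensor):
        assert len(data) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)
    else:
        raise Exception(f"Couldn't identify {filename} as a textual inversion embedding or diffuser concept.")

    return emb.detach().to(device, dtype=torch.float32), None, step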