265 lines
8.9 KiB
Python
Executable File
265 lines
8.9 KiB
Python
Executable File
import os
|
|
import torch
|
|
import base64
|
|
import json
|
|
import zlib
|
|
import numpy as np
|
|
import safetensors.torch
|
|
|
|
from PIL import Image
|
|
|
|
|
|
class EmbeddingEncoder(json.JSONEncoder):
|
|
def default(self, obj):
|
|
if isinstance(obj, torch.Tensor):
|
|
return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
|
|
return json.JSONEncoder.default(self, obj)
|
|
|
|
|
|
class EmbeddingDecoder(json.JSONDecoder):
|
|
def __init__(self, *args, **kwargs):
|
|
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
|
|
|
|
def object_hook(self, d):
|
|
if 'TORCHTENSOR' in d:
|
|
return torch.from_numpy(np.array(d['TORCHTENSOR']))
|
|
return d
|
|
|
|
|
|
def embedding_to_b64(data):
|
|
d = json.dumps(data, cls=EmbeddingEncoder)
|
|
return base64.b64encode(d.encode())
|
|
|
|
|
|
def embedding_from_b64(data):
|
|
d = base64.b64decode(data)
|
|
return json.loads(d, cls=EmbeddingDecoder)
|
|
|
|
|
|
def lcg(m=2 ** 32, a=1664525, c=1013904223, seed=0):
|
|
while True:
|
|
seed = (a * seed + c) % m
|
|
yield seed % 255
|
|
|
|
|
|
def xor_block(block):
|
|
g = lcg()
|
|
randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
|
|
return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
|
|
|
|
|
|
def crop_black(img, tol=0):
|
|
mask = (img > tol).all(2)
|
|
mask0, mask1 = mask.any(0), mask.any(1)
|
|
col_start, col_end = mask0.argmax(), mask.shape[1] - mask0[::-1].argmax()
|
|
row_start, row_end = mask1.argmax(), mask.shape[0] - mask1[::-1].argmax()
|
|
return img[row_start:row_end, col_start:col_end]
|
|
|
|
|
|
def extract_image_data_embed(image):
|
|
d = 3
|
|
outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
|
|
black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
|
|
if black_cols[0].shape[0] < 2:
|
|
print(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.')
|
|
return None
|
|
|
|
data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
|
|
data_block_upper = outarr[:, black_cols[0].max() + 1:, :].astype(np.uint8)
|
|
|
|
data_block_lower = xor_block(data_block_lower)
|
|
data_block_upper = xor_block(data_block_upper)
|
|
|
|
data_block = (data_block_upper << 4) | (data_block_lower)
|
|
data_block = data_block.flatten().tobytes()
|
|
|
|
data = zlib.decompress(data_block)
|
|
return json.loads(data, cls=EmbeddingDecoder)
|
|
|
|
|
|
class Embedding:
|
|
def __init__(self, vec, name, step=None):
|
|
self.vec = vec
|
|
self.name = name
|
|
self.step = step
|
|
self.shape = None
|
|
self.vectors = 0
|
|
self.sd_checkpoint = None
|
|
self.sd_checkpoint_name = None
|
|
|
|
|
|
class DirWithTextualInversionEmbeddings:
|
|
def __init__(self, path):
|
|
self.path = path
|
|
self.mtime = None
|
|
|
|
def has_changed(self):
|
|
if not os.path.isdir(self.path):
|
|
return False
|
|
|
|
mt = os.path.getmtime(self.path)
|
|
if self.mtime is None or mt > self.mtime:
|
|
return True
|
|
|
|
def update(self):
|
|
if not os.path.isdir(self.path):
|
|
return
|
|
|
|
self.mtime = os.path.getmtime(self.path)
|
|
|
|
|
|
class EmbeddingDatabase:
|
|
def __init__(self, tokenizer, expected_shape=-1):
|
|
self.ids_lookup = {}
|
|
self.word_embeddings = {}
|
|
self.embedding_dirs = {}
|
|
self.skipped_embeddings = {}
|
|
self.expected_shape = expected_shape
|
|
self.tokenizer = tokenizer
|
|
self.fixes = []
|
|
|
|
def add_embedding_dir(self, path):
|
|
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
|
|
|
def clear_embedding_dirs(self):
|
|
self.embedding_dirs.clear()
|
|
|
|
def register_embedding(self, embedding):
|
|
return self.register_embedding_by_name(embedding, embedding.name)
|
|
|
|
def register_embedding_by_name(self, embedding, name):
|
|
ids = self.tokenizer([name], truncation=False, add_special_tokens=False)["input_ids"][0]
|
|
first_id = ids[0]
|
|
if first_id not in self.ids_lookup:
|
|
self.ids_lookup[first_id] = []
|
|
if name in self.word_embeddings:
|
|
lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
|
|
else:
|
|
lookup = self.ids_lookup[first_id]
|
|
if embedding is not None:
|
|
lookup += [(ids, embedding)]
|
|
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
|
|
if embedding is None:
|
|
if name in self.word_embeddings:
|
|
del self.word_embeddings[name]
|
|
if len(self.ids_lookup[first_id]) == 0:
|
|
del self.ids_lookup[first_id]
|
|
return None
|
|
self.word_embeddings[name] = embedding
|
|
return embedding
|
|
|
|
def load_from_file(self, path, filename):
|
|
name, ext = os.path.splitext(filename)
|
|
ext = ext.upper()
|
|
|
|
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
|
_, second_ext = os.path.splitext(name)
|
|
if second_ext.upper() == '.PREVIEW':
|
|
return
|
|
|
|
embed_image = Image.open(path)
|
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
|
name = data.get('name', name)
|
|
else:
|
|
data = extract_image_data_embed(embed_image)
|
|
if data:
|
|
name = data.get('name', name)
|
|
else:
|
|
return
|
|
elif ext in ['.BIN', '.PT']:
|
|
data = torch.load(path, map_location="cpu")
|
|
elif ext in ['.SAFETENSORS']:
|
|
data = safetensors.torch.load_file(path, device="cpu")
|
|
else:
|
|
return
|
|
|
|
if data is not None:
|
|
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
|
|
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
|
self.register_embedding(embedding)
|
|
else:
|
|
self.skipped_embeddings[name] = embedding
|
|
else:
|
|
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
|
|
|
|
def load_from_dir(self, embdir):
|
|
if not os.path.isdir(embdir.path):
|
|
return
|
|
|
|
for root, _, fns in os.walk(embdir.path, followlinks=True):
|
|
for fn in fns:
|
|
try:
|
|
fullfn = os.path.join(root, fn)
|
|
|
|
if os.stat(fullfn).st_size == 0:
|
|
continue
|
|
|
|
self.load_from_file(fullfn, fn)
|
|
except Exception:
|
|
print(f"Error loading embedding {fn}")
|
|
continue
|
|
|
|
def load_textual_inversion_embeddings(self):
|
|
self.ids_lookup.clear()
|
|
self.word_embeddings.clear()
|
|
self.skipped_embeddings.clear()
|
|
|
|
for embdir in self.embedding_dirs.values():
|
|
self.load_from_dir(embdir)
|
|
embdir.update()
|
|
|
|
return
|
|
|
|
def find_embedding_at_position(self, tokens, offset):
|
|
token = tokens[offset]
|
|
possible_matches = self.ids_lookup.get(token, None)
|
|
|
|
if possible_matches is None:
|
|
return None, None
|
|
|
|
for ids, embedding in possible_matches:
|
|
if tokens[offset:offset + len(ids)] == ids:
|
|
return embedding, len(ids)
|
|
|
|
return None, None
|
|
|
|
|
|
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
|
|
if 'string_to_param' in data: # textual inversion embeddings
|
|
param_dict = data['string_to_param']
|
|
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
emb = next(iter(param_dict.items()))[1]
|
|
vec = emb.detach().to(dtype=torch.float32)
|
|
shape = vec.shape[-1]
|
|
vectors = vec.shape[0]
|
|
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
|
vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()}
|
|
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
|
vectors = data['clip_g'].shape[0]
|
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
|
|
emb = next(iter(data.values()))
|
|
if len(emb.shape) == 1:
|
|
emb = emb.unsqueeze(0)
|
|
vec = emb.detach().to(dtype=torch.float32)
|
|
shape = vec.shape[-1]
|
|
vectors = vec.shape[0]
|
|
else:
|
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
|
|
|
embedding = Embedding(vec, name)
|
|
embedding.step = data.get('step', None)
|
|
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
|
embedding.vectors = vectors
|
|
embedding.shape = shape
|
|
|
|
if filepath:
|
|
embedding.filename = filepath
|
|
|
|
return embedding
|