diff --git a/backend/text_processing/engine.py b/backend/text_processing/engine.py new file mode 100644 index 00000000..d7a87963 --- /dev/null +++ b/backend/text_processing/engine.py @@ -0,0 +1,269 @@ +import math +from collections import namedtuple + +import torch + +from backend.text_processing import parsing, emphasis + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) + + +class PromptChunk: + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +class ClassicTextProcessingEngine(torch.nn.Module): + def __init__(self, wrapped, hijack): + super().__init__() + self.chunk_length = 75 + + self.is_trainable = False + self.input_key = 'txt' + self.return_pooled = False + + self.comma_token = None + + self.hijack = hijack + + self.wrapped = wrapped + + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.return_pooled = getattr(self.wrapped, 'return_pooled', False) + + self.legacy_ucg_val = None # for sgm codebase + + self.tokenizer = wrapped.tokenizer + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) + + self.token_mults = {} + + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def empty_chunk(self): + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + raise NotImplementedError + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.transformer.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) + return embedded + + def tokenize_line(self, line): + parsed = parsing.parse_prompt_attention(line) + + tokenized = self.tokenize([text for text, _ in parsed]) + + chunks = [] + chunk = PromptChunk() + token_count = 0 + last_comma = -1 + + def next_chunk(is_last=False): + """puts current chunk into the list of results and produces the next one - empty; + if is_last is true, tokens tokens at the end won't add to token_count""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + if is_last: + token_count += len(chunk.tokens) + else: + token_count += self.chunk_length + + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add + + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + next_chunk() + continue + + position = 0 + while position < len(tokens): + token = tokens[position] + + comma_padding_backtrack = 20 + + if token == self.comma_token: + last_comma = len(chunk.tokens) + + elif comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] + + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] + + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) + if embedding is None: + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vectors) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if chunk.tokens or not chunks: + next_chunk(is_last=True) + + return chunks, token_count + + def process_texts(self, texts): + token_count = 0 + + cache = {} + batch_chunks = [] + for line in texts: + if line in cache: + chunks = cache[line] + else: + chunks, current_token_count = self.tokenize_line(line) + token_count = max(current_token_count, token_count) + + cache[line] = chunks + + batch_chunks.append(chunks) + + return batch_chunks, token_count + + def forward(self, texts): + batch_chunks, token_count = self.process_texts(texts) + + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) + + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + + for fixes in self.hijack.fixes: + for _position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + + if used_embeddings: + for name, embedding in used_embeddings.items(): + print(f'Used Embedding: {name}') + + # Todo: + # if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: + # hashes = [] + # for name, embedding in used_embeddings.items(): + # shorthash = embedding.shorthash + # if not shorthash: + # continue + # + # name = name.replace(":", "").replace(",", "") + # hashes.append(f"{name}: {shorthash}") + # + # if hashes: + # if self.hijack.extra_generation_params.get("TI hashes"): + # hashes.append(self.hijack.extra_generation_params.get("TI hashes")) + # self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) + # + # if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": + # self.hijack.extra_generation_params["Emphasis"] = opts.emphasis + + if self.return_pooled: + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + tokens = torch.asarray(remade_batch_tokens) + + if self.id_end != self.id_pad: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(self.id_end) + tokens[batch_pos, index + 1:tokens.shape[1]] = self.id_pad + + z = self.encode_with_transformers(tokens) + + pooled = getattr(z, 'pooled', None) + + # Todo + # e = emphasis.get_current_option(opts.emphasis)() + + e = emphasis.EmphasisOriginal() + e.tokens = remade_batch_tokens + e.multipliers = torch.asarray(batch_multipliers) + e.z = z + e.after_transformers() + z = e.z + + if pooled is not None: + z.pooled = pooled + + return z diff --git a/backend/text_processing/parsing.py b/backend/text_processing/parsing.py new file mode 100644 index 00000000..07dcdc16 --- /dev/null +++ b/backend/text_processing/parsing.py @@ -0,0 +1,75 @@ +import re + + +re_attention = re.compile(r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:\s*([+-]?[.\d]+)\s*\)| +\)| +]| +[^\\()\[\]:]+| +: +""", re.X) + +re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) + + +def parse_prompt_attention(text): + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith('\\'): + res.append([text[1:], 1.0]) + elif text == '(': + round_brackets.append(len(res)) + elif text == '[': + square_brackets.append(len(res)) + elif weight is not None and round_brackets: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ')' and round_brackets: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == ']' and square_brackets: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + parts = re.split(re_break, text) + for i, part in enumerate(parts): + if i > 0: + res.append(["BREAK", -1]) + res.append([part, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res