diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py index bd2cf674..3f82fcd6 100644 --- a/backend/text_processing/classic_engine.py +++ b/backend/text_processing/classic_engine.py @@ -6,6 +6,8 @@ from backend.text_processing import parsing, emphasis from backend.text_processing.textual_inversion import EmbeddingDatabase from backend import memory_management +from modules.shared import opts + PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) last_extra_generation_params = {} @@ -67,7 +69,7 @@ class ClassicTextProcessingEngine: self.text_encoder = text_encoder self.tokenizer = tokenizer - self.emphasis = emphasis.get_current_option(emphasis_name)() + self.emphasis = emphasis.get_current_option(opts.emphasis)() self.text_projection = text_projection self.minimal_clip_skip = minimal_clip_skip self.clip_skip = clip_skip @@ -146,7 +148,7 @@ class ClassicTextProcessingEngine: return z def tokenize_line(self, line): - parsed = parsing.parse_prompt_attention(line) + parsed = parsing.parse_prompt_attention(line, self.emphasis.name) tokenized = self.tokenize([text for text, _ in parsed]) @@ -248,6 +250,8 @@ class ClassicTextProcessingEngine: return batch_chunks, token_count def __call__(self, texts): + self.emphasis = emphasis.get_current_option(opts.emphasis)() + batch_chunks, token_count = self.process_texts(texts) used_embeddings = {} diff --git a/backend/text_processing/parsing.py b/backend/text_processing/parsing.py index 07dcdc16..cdc49122 100644 --- a/backend/text_processing/parsing.py +++ b/backend/text_processing/parsing.py @@ -20,7 +20,7 @@ re_attention = re.compile(r""" re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) -def parse_prompt_attention(text): +def parse_prompt_attention(text, emphasis): res = [] round_brackets = [] square_brackets = [] @@ -32,44 +32,48 @@ def parse_prompt_attention(text): for p in range(start_position, len(res)): res[p][1] *= multiplier - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) + if emphasis == "None": + # interpret literally + res = [[text, 1.0]] + else: + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) - if text.startswith('\\'): - res.append([text[1:], 1.0]) - elif text == '(': - round_brackets.append(len(res)) - elif text == '[': - square_brackets.append(len(res)) - elif weight is not None and round_brackets: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and round_brackets: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and square_brackets: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - parts = re.split(re_break, text) - for i, part in enumerate(parts): - if i > 0: - res.append(["BREAK", -1]) - res.append([part, 1.0]) + if text.startswith('\\'): + res.append([text[1:], 1.0]) + elif text == '(': + round_brackets.append(len(res)) + elif text == '[': + square_brackets.append(len(res)) + elif weight is not None and round_brackets: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ')' and round_brackets: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == ']' and square_brackets: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + parts = re.split(re_break, text) + for i, part in enumerate(parts): + if i > 0: + res.append(["BREAK", -1]) + res.append([part, 1.0]) - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) - if len(res) == 0: - res = [["", 1.0]] + if len(res) == 0: + res = [["", 1.0]] - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 return res diff --git a/backend/text_processing/t5_engine.py b/backend/text_processing/t5_engine.py index a6b3dc47..e00ccccf 100644 --- a/backend/text_processing/t5_engine.py +++ b/backend/text_processing/t5_engine.py @@ -4,6 +4,8 @@ from collections import namedtuple from backend.text_processing import parsing, emphasis from backend import memory_management +from modules.shared import opts + PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) @@ -21,7 +23,7 @@ class T5TextProcessingEngine: self.text_encoder = text_encoder.transformer self.tokenizer = tokenizer - self.emphasis = emphasis.get_current_option(emphasis_name)() + self.emphasis = emphasis.get_current_option(opts.emphasis)() self.min_length = min_length self.id_end = 1 self.id_pad = 0 @@ -64,7 +66,7 @@ class T5TextProcessingEngine: return z def tokenize_line(self, line): - parsed = parsing.parse_prompt_attention(line) + parsed = parsing.parse_prompt_attention(line, self.emphasis.name) tokenized = self.tokenize([text for text, _ in parsed]) @@ -111,6 +113,8 @@ class T5TextProcessingEngine: zs = [] cache = {} + self.emphasis = emphasis.get_current_option(opts.emphasis)() + for line in texts: if line in cache: line_z_values = cache[line]