Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-01-27 03:19:47 +00:00
Previously, emphasis mode changes took effect only on model load. Also added handling for mode 'None' (literal interpretation), which previously gave results identical to mode 'Ignore'.
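The file below resolves the emphasis mode in two places: once in __init__ (which runs at model load) and again at the top of __call__, which is what lets a changed setting take effect without reloading the model. A rough sketch of the two relevant lines:

    # in __init__ (runs at model load):
    self.emphasis = emphasis.get_current_option(opts.emphasis)()

    # in __call__ (runs for every prompt encode):
    self.emphasis = emphasis.get_current_option(opts.emphasis)()

The prompt parser is also handed the current mode's name (parse_prompt_attention(line, self.emphasis.name)), which is presumably what allows mode 'None' to take the prompt literally instead of behaving like 'Ignore'.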
159 lines
4.7 KiB
Python
import torch

from collections import namedtuple

from backend.text_processing import parsing, emphasis
from backend import memory_management

from modules.shared import opts


# marker saying that a textual inversion embedding should be placed at `offset` within a prompt chunk
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])


# one chunk of a prompt: a list of token ids plus a parallel list of emphasis multipliers
class PromptChunk:
    def __init__(self):
        self.tokens = []
        self.multipliers = []


class T5TextProcessingEngine:
    def __init__(self, text_encoder, tokenizer, emphasis_name="Original", min_length=256):
        super().__init__()

        self.text_encoder = text_encoder.transformer
        self.tokenizer = tokenizer

        # the emphasis mode is re-resolved from opts on every __call__ as well,
        # so a changed setting takes effect without reloading the model
        self.emphasis = emphasis.get_current_option(opts.emphasis)()
        self.min_length = min_length
        self.id_end = 1  # T5 end-of-sequence token id (</s>)
        self.id_pad = 0  # T5 padding token id (<pad>)

        vocab = self.tokenizer.get_vocab()

        # comma token id if present in the vocab, otherwise None
        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}

        # precompute emphasis multipliers for vocab entries that themselves contain
        # attention brackets: '(' and ']' raise the weight by 1.1, '[' and ')' lower it
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult
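
        # illustrative example (not from the original source): a vocab entry whose text
        # contains '((' ends up with multiplier 1.1 * 1.1 = 1.21, while one containing a
        # single '[' gets 1 / 1.1 ≈ 0.909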

    def tokenize(self, texts):
        # no truncation and no special tokens here; EOS and padding are added per chunk in tokenize_line
        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
        return tokenized

    def encode_with_transformers(self, tokens):
        device = memory_management.text_encoder_device()
        tokens = tokens.to(device)
        # `shared` is the T5 token embedding table; keep it in fp32 on the text-encoder device
        self.text_encoder.shared.to(device=device, dtype=torch.float32)

        z = self.text_encoder(
            input_ids=tokens,
        )

        return z

    def tokenize_line(self, line):
        parsed = parsing.parse_prompt_attention(line, self.emphasis.name)

        tokenized = self.tokenize([text for text, _ in parsed])

        chunks = []
        chunk = PromptChunk()
        token_count = 0

        def next_chunk():
            # close the current chunk: append EOS, pad up to min_length, and start a new chunk
            nonlocal token_count
            nonlocal chunk

            chunk.tokens = chunk.tokens + [self.id_end]
            chunk.multipliers = chunk.multipliers + [1.0]
            current_chunk_length = len(chunk.tokens)

            token_count += current_chunk_length
            remaining_count = self.min_length - current_chunk_length

            if remaining_count > 0:
                chunk.tokens += [self.id_pad] * remaining_count
                chunk.multipliers += [1.0] * remaining_count

            chunks.append(chunk)
            chunk = PromptChunk()

        for tokens, (text, weight) in zip(tokenized, parsed):
            # the BREAK keyword (marked by the parser with weight -1) forces a chunk boundary
            if text == 'BREAK' and weight == -1:
                next_chunk()
                continue

            position = 0
            while position < len(tokens):
                token = tokens[position]
                chunk.tokens.append(token)
                chunk.multipliers.append(weight)
                position += 1

        if chunk.tokens or not chunks:
            next_chunk()

        return chunks, token_count
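
    # illustrative example (not from the original source): with min_length=256, a prompt
    # like "a cat BREAK a dog" parses into two chunks; each chunk gets the EOS token (id 1)
    # appended and is padded with pad tokens (id 0) up to 256 entries, with multiplier 1.0
    # for the EOS/pad positions, so tokenize_line returns ([chunk1, chunk2], token_count)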

    def __call__(self, texts):
        zs = []
        cache = {}

        # re-resolve the emphasis mode on every call so that changing the setting
        # takes effect without reloading the model
        self.emphasis = emphasis.get_current_option(opts.emphasis)()

        for line in texts:
            if line in cache:
                line_z_values = cache[line]
            else:
                chunks, token_count = self.tokenize_line(line)
                line_z_values = []

                # pad all chunks to length of longest chunk
                max_tokens = 0
                for chunk in chunks:
                    max_tokens = max(len(chunk.tokens), max_tokens)

                for chunk in chunks:
                    tokens = chunk.tokens
                    multipliers = chunk.multipliers

                    remaining_count = max_tokens - len(tokens)
                    if remaining_count > 0:
                        tokens += [self.id_pad] * remaining_count
                        multipliers += [1.0] * remaining_count

                    z = self.process_tokens([tokens], [multipliers])[0]
                    line_z_values.append(z)
                cache[line] = line_z_values

            zs.extend(line_z_values)

        return torch.stack(zs)

    def process_tokens(self, batch_tokens, batch_multipliers):
        tokens = torch.asarray(batch_tokens)

        z = self.encode_with_transformers(tokens)

        # hand the hidden states and per-token multipliers to the emphasis
        # implementation, which applies the selected emphasis mode to z
        self.emphasis.tokens = batch_tokens
        self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
        self.emphasis.z = z
        self.emphasis.after_transformers()
        z = self.emphasis.z

        return z
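
For orientation, a minimal usage sketch; the encoder and tokenizer objects are hypothetical stand-ins, and only the constructor signature and call pattern are taken from the class above.

# Hypothetical usage sketch (not part of the original file). t5_encoder must expose a
# .transformer module and t5_tokenizer a HuggingFace-style get_vocab(); both are assumed
# to come from whatever model the backend has already loaded.
engine = T5TextProcessingEngine(text_encoder=t5_encoder, tokenizer=t5_tokenizer, min_length=256)

# Returns one stacked tensor with one row per chunk; a prompt without BREAK that fits
# within min_length tokens produces exactly one chunk.
cond = engine(["a photograph of an astronaut riding a horse"])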