# Source listing metadata: 80 lines, 2.0 KiB, Python, executable file.
import re
|
|
|
|
|
|
# Tokenizer for the prompt emphasis syntax.  Alternatives are tried in order,
# so the escape sequences come before the bare brackets they protect:
#
#   \(  \)  \[  \]  \\     escaped bracket / escaped backslash
#   \                      a lone backslash
#   (   [                  opening brackets
#   : <number> )           explicit weight together with its closing ")";
#                          group 1 captures the number text.  The number
#                          pattern [.\d]+ is loose and also accepts text such
#                          as "1.2.3" that float() cannot parse.
#   )   ]                  closing brackets
#   [^\\()\[\]:]+          a run of ordinary text
#   :                      a colon that does not introduce a weight
#
# Under re.X the newlines inside the pattern are ignored, so every line must
# hold exactly one alternative.  A line containing only "|" would add an empty
# alternative and make the pattern match "" at every position.
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)

# Matches the whole word BREAK plus any whitespace around it, so splitting on
# this pattern also trims the neighbouring text.  re.S only changes the
# meaning of ".", which this pattern does not use.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
|
|
|
|
|
def parse_prompt_attention(text, emphasis):
    r"""Split a prompt into ``[text, weight]`` pairs using the emphasis syntax.

    Syntax recognised when *emphasis* is not ``"None"``:

    * ``(abc)``      - multiply the weight of ``abc`` by 1.1
    * ``(abc:3.12)`` - multiply the weight of ``abc`` by 3.12
    * ``[abc]``      - multiply the weight of ``abc`` by 1 / 1.1
    * ``\(`` ``\)`` ``\[`` ``\]`` ``\\`` - the literal character after the
      backslash; a backslash before anything else is dropped
    * ``BREAK`` as a whole word - emitted as the marker ``["BREAK", -1]``,
      with the whitespace around it removed

    Brackets nest and their multipliers compound.  Brackets still open at the
    end of the prompt are closed there with their default multiplier.  A
    closing bracket with no matching opener is kept as ordinary text.

    Args:
        text: The prompt string.
        emphasis: When equal to the string ``"None"`` the prompt is returned
            unparsed as ``[[text, 1.0]]``.  Any other value enables parsing;
            the value is not inspected further.

    Returns:
        A non-empty list of ``[str, weight]`` lists in prompt order, with
        neighbouring entries of equal weight merged.  An empty result is
        returned as ``[["", 1.0]]``.
    """
    res = []
    # Each stack holds, per still-open bracket, the index in ``res`` where
    # that bracket's content starts.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale everything emitted since the bracket was opened.
        # NOTE(review): this also scales BREAK markers inside the bracket, so
        # their weight stops being exactly -1 - confirm consumers expect that.
        for entry in res[start_position:]:
            entry[1] *= multiplier

    if emphasis == "None":
        # Emphasis disabled: take the prompt literally, brackets included.
        res = [[text, 1.0]]
    else:
        for m in re_attention.finditer(text):
            token = m.group(0)
            weight = m.group(1)

            if token.startswith('\\'):
                # Escape sequence: keep what follows the backslash.
                res.append([token[1:], 1.0])
            elif token == '(':
                round_brackets.append(len(res))
            elif token == '[':
                square_brackets.append(len(res))
            elif weight is not None and round_brackets:
                try:
                    multiplier = float(weight)
                except ValueError:
                    # The tokenizer accepts text such as "1.2.3" or "." as a
                    # weight.  Keep it as literal text (without the ")") and
                    # close the bracket with the default multiplier, which is
                    # what a plain ":" followed by text and ")" produces.
                    res.append([token[:-1], 1.0])
                    multiplier = round_bracket_multiplier
                multiply_range(round_brackets.pop(), multiplier)
            elif token == ')' and round_brackets:
                multiply_range(round_brackets.pop(), round_bracket_multiplier)
            elif token == ']' and square_brackets:
                multiply_range(square_brackets.pop(), square_bracket_multiplier)
            else:
                # Ordinary text, including unmatched closers and weights with
                # no open "(".  Each BREAK becomes a marker between the parts.
                for i, part in enumerate(re_break.split(token)):
                    if i > 0:
                        res.append(["BREAK", -1])
                    res.append([part, 1.0])

        # Close whatever the prompt left open.
        for pos in round_brackets:
            multiply_range(pos, round_bracket_multiplier)

        for pos in square_brackets:
            multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # Merge runs of equal weight in one pass.
    merged = [res[0]]
    for entry in res[1:]:
        if entry[1] == merged[-1][1]:
            merged[-1][0] += entry[0]
        else:
            merged.append(entry)

    return merged
|