mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 14:30:03 +00:00
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
|
|
|
|
|
|
# Matches any run of one or more whitespace characters.
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
# U+2581 (LOWER ONE EIGHTH BLOCK); presumably a visible stand-in for spaces
# used by callers -- it is not referenced by the functions in this module.
SPACE_ESCAPE = chr(9601)

# Byte values that are represented by the Latin-1 character with the same code
# point. Everything else is left out: 0-31 and 127-159, the non-breaking
# space (160) and code point 173.
PRINTABLE_LATIN = (
    set(range(32, 126 + 1)) | set(range(161, 172 + 1)) | set(range(174, 255 + 1))
)

# Every byte value 0..255 gets exactly one character: itself when it is in
# PRINTABLE_LATIN, otherwise the code point shifted up by 256 so that it lands
# in 256..511 and cannot collide with a directly-mapped byte.
BYTE_TO_BCHAR = {
    b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
# Inverse lookup: byte-character -> original byte value.
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
|
|
|
|
|
|
def byte_encode(x: str) -> str:
    """Encode *x* as a string of byte-characters, one per UTF-8 byte.

    Every run of whitespace in *x* is first replaced by a single ``SPACE``
    (leading and trailing runs are collapsed too, not removed). The result
    is then UTF-8 encoded and each byte is mapped through ``BYTE_TO_BCHAR``.

    Args:
        x: The text to encode.

    Returns:
        A string whose length equals the number of UTF-8 bytes of the
        whitespace-normalized input.
    """
    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
|
|
|
|
|
|
def byte_decode(x: str) -> str:
    """Decode a byte-character string back into text.

    Each character of *x* is mapped to its byte through ``BCHAR_TO_BYTE`` and
    the resulting byte sequence is decoded as UTF-8.

    Args:
        x: A string of byte-characters.

    Returns:
        The decoded text, or ``""`` when *x* cannot be decoded -- either
        because it contains a character that is not a key of
        ``BCHAR_TO_BYTE`` or because the bytes are not valid UTF-8. An empty
        *x* also yields ``""``, so the empty result alone does not
        distinguish failure from empty input.
    """
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except (KeyError, ValueError):
        # KeyError: character outside the byte-character alphabet.
        # ValueError: covers UnicodeDecodeError raised for invalid UTF-8.
        return ""
|
|
|
|
|
|
def smart_byte_decode(x: str) -> str:
    """Decode *x*, recovering as much text as possible if it is malformed.

    A plain ``byte_decode`` is tried first. When that yields ``""``, a dynamic
    program picks non-overlapping segments of 1 to 4 byte-characters that each
    decode successfully, maximizing the number of decoded segments, and
    returns their decoded text concatenated in order. Positions that belong
    to no chosen segment are dropped.

    Args:
        x: A string of byte-characters.

    Returns:
        The fully decoded text when possible, otherwise the best-effort
        recovery (``""`` if nothing can be decoded).
    """
    output = byte_decode(x)
    if output != "":
        return output

    n_bytes = len(x)
    # best[i]  -- max number of decodable segments found within x[:i]
    # prev[i]  -- index the optimal solution for x[:i] continues from
    # piece[i] -- decoded text of the segment ending at i, "" if position
    #             i - 1 is skipped in the optimal solution
    best = [0] * (n_bytes + 1)
    prev = [0] * (n_bytes + 1)
    piece = [""] * (n_bytes + 1)
    for i in range(1, n_bytes + 1):
        # Default: skip position i - 1 and keep the previous score.
        best[i], prev[i] = best[i - 1], i - 1
        # Try segments of up to 4 byte-characters ending at i (a UTF-8
        # encoded character occupies at most 4 bytes).
        for j in range(1, min(4, i) + 1):
            if best[i - j] + 1 > best[i]:
                decoded = byte_decode(x[i - j : i])
                if decoded:
                    best[i], prev[i], piece[i] = best[i - j] + 1, i - j, decoded

    # Walk the chain backwards, collecting the chosen segments, then restore
    # left-to-right order with a single join.
    pieces = []
    cur_pt = n_bytes
    while cur_pt > 0:
        if piece[cur_pt]:
            pieces.append(piece[cur_pt])
        cur_pt = prev[cur_pt]
    return "".join(reversed(pieces))
|