mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-22 06:04:26 +00:00
402 lines
13 KiB
Python
402 lines
13 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import os
|
|
from collections import Counter
|
|
from multiprocessing import Pool
|
|
|
|
import torch
|
|
from fairseq import utils
|
|
from fairseq.data import data_utils
|
|
from fairseq.file_chunker_utils import Chunker, find_offsets
|
|
from fairseq.file_io import PathManager
|
|
from fairseq.tokenizer import tokenize_line
|
|
|
|
|
|
class Dictionary:
|
|
"""A mapping from symbols to consecutive integers"""
|
|
|
|
def __init__(
|
|
self,
|
|
*, # begin keyword-only arguments
|
|
bos="<s>",
|
|
pad="<pad>",
|
|
eos="</s>",
|
|
unk="<unk>",
|
|
extra_special_symbols=None,
|
|
):
|
|
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
|
self.symbols = []
|
|
self.count = []
|
|
self.indices = {}
|
|
self.bos_index = self.add_symbol(bos)
|
|
self.pad_index = self.add_symbol(pad)
|
|
self.eos_index = self.add_symbol(eos)
|
|
self.unk_index = self.add_symbol(unk)
|
|
if extra_special_symbols:
|
|
for s in extra_special_symbols:
|
|
self.add_symbol(s)
|
|
self.nspecial = len(self.symbols)
|
|
|
|
def __eq__(self, other):
|
|
return self.indices == other.indices
|
|
|
|
def __getitem__(self, idx):
|
|
if idx < len(self.symbols):
|
|
return self.symbols[idx]
|
|
return self.unk_word
|
|
|
|
def get_count(self, idx):
|
|
return self.count[idx]
|
|
|
|
def __len__(self):
|
|
"""Returns the number of symbols in the dictionary"""
|
|
return len(self.symbols)
|
|
|
|
def __contains__(self, sym):
|
|
return sym in self.indices
|
|
|
|
def index(self, sym):
|
|
"""Returns the index of the specified symbol"""
|
|
assert isinstance(sym, str)
|
|
if sym in self.indices:
|
|
return self.indices[sym]
|
|
return self.unk_index
|
|
|
|
def string(
|
|
self,
|
|
tensor,
|
|
bpe_symbol=None,
|
|
escape_unk=False,
|
|
extra_symbols_to_ignore=None,
|
|
unk_string=None,
|
|
include_eos=False,
|
|
separator=" ",
|
|
):
|
|
"""Helper for converting a tensor of token indices to a string.
|
|
|
|
Can optionally remove BPE symbols or escape <unk> words.
|
|
"""
|
|
if torch.is_tensor(tensor) and tensor.dim() == 2:
|
|
return "\n".join(
|
|
self.string(
|
|
t,
|
|
bpe_symbol,
|
|
escape_unk,
|
|
extra_symbols_to_ignore,
|
|
include_eos=include_eos,
|
|
)
|
|
for t in tensor
|
|
)
|
|
|
|
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
|
|
if not include_eos:
|
|
extra_symbols_to_ignore.add(self.eos())
|
|
|
|
def token_string(i):
|
|
if i == self.unk():
|
|
if unk_string is not None:
|
|
return unk_string
|
|
else:
|
|
return self.unk_string(escape_unk)
|
|
else:
|
|
return self[i]
|
|
|
|
if hasattr(self, "bos_index"):
|
|
extra_symbols_to_ignore.add(self.bos())
|
|
|
|
sent = separator.join(
|
|
token_string(i)
|
|
for i in tensor
|
|
if utils.item(i) not in extra_symbols_to_ignore
|
|
)
|
|
|
|
return data_utils.post_process(sent, bpe_symbol)
|
|
|
|
def unk_string(self, escape=False):
|
|
"""Return unknown string, optionally escaped as: <<unk>>"""
|
|
if escape:
|
|
return "<{}>".format(self.unk_word)
|
|
else:
|
|
return self.unk_word
|
|
|
|
def add_symbol(self, word, n=1, overwrite=False):
|
|
"""Adds a word to the dictionary"""
|
|
if word in self.indices and not overwrite:
|
|
idx = self.indices[word]
|
|
self.count[idx] = self.count[idx] + n
|
|
return idx
|
|
else:
|
|
idx = len(self.symbols)
|
|
self.indices[word] = idx
|
|
self.symbols.append(word)
|
|
self.count.append(n)
|
|
return idx
|
|
|
|
def update(self, new_dict):
|
|
"""Updates counts from new dictionary."""
|
|
for word in new_dict.symbols:
|
|
idx2 = new_dict.indices[word]
|
|
if word in self.indices:
|
|
idx = self.indices[word]
|
|
self.count[idx] = self.count[idx] + new_dict.count[idx2]
|
|
else:
|
|
idx = len(self.symbols)
|
|
self.indices[word] = idx
|
|
self.symbols.append(word)
|
|
self.count.append(new_dict.count[idx2])
|
|
|
|
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
|
|
"""Sort symbols by frequency in descending order, ignoring special ones.
|
|
|
|
Args:
|
|
- threshold defines the minimum word count
|
|
- nwords defines the total number of words in the final dictionary,
|
|
including special symbols
|
|
- padding_factor can be used to pad the dictionary size to be a
|
|
multiple of 8, which is important on some hardware (e.g., Nvidia
|
|
Tensor Cores).
|
|
"""
|
|
if nwords <= 0:
|
|
nwords = len(self)
|
|
|
|
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
|
|
new_symbols = self.symbols[: self.nspecial]
|
|
new_count = self.count[: self.nspecial]
|
|
|
|
c = Counter(
|
|
dict(
|
|
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
|
|
)
|
|
)
|
|
for symbol, count in c.most_common(nwords - self.nspecial):
|
|
if count >= threshold:
|
|
new_indices[symbol] = len(new_symbols)
|
|
new_symbols.append(symbol)
|
|
new_count.append(count)
|
|
else:
|
|
break
|
|
|
|
assert len(new_symbols) == len(new_indices)
|
|
|
|
self.count = list(new_count)
|
|
self.symbols = list(new_symbols)
|
|
self.indices = new_indices
|
|
|
|
self.pad_to_multiple_(padding_factor)
|
|
|
|
def pad_to_multiple_(self, padding_factor):
|
|
"""Pad Dictionary size to be a multiple of *padding_factor*."""
|
|
if padding_factor > 1:
|
|
i = 0
|
|
while len(self) % padding_factor != 0:
|
|
symbol = "madeupword{:04d}".format(i)
|
|
self.add_symbol(symbol, n=0)
|
|
i += 1
|
|
|
|
def bos(self):
|
|
"""Helper to get index of beginning-of-sentence symbol"""
|
|
return self.bos_index
|
|
|
|
def pad(self):
|
|
"""Helper to get index of pad symbol"""
|
|
return self.pad_index
|
|
|
|
def eos(self):
|
|
"""Helper to get index of end-of-sentence symbol"""
|
|
return self.eos_index
|
|
|
|
def unk(self):
|
|
"""Helper to get index of unk symbol"""
|
|
return self.unk_index
|
|
|
|
@classmethod
|
|
def load(cls, f):
|
|
"""Loads the dictionary from a text file with the format:
|
|
|
|
```
|
|
<symbol0> <count0>
|
|
<symbol1> <count1>
|
|
...
|
|
```
|
|
"""
|
|
d = cls()
|
|
d.add_from_file(f)
|
|
return d
|
|
|
|
def add_from_file(self, f):
|
|
"""
|
|
Loads a pre-existing dictionary from a text file and adds its symbols
|
|
to this instance.
|
|
"""
|
|
if isinstance(f, str):
|
|
try:
|
|
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
|
|
self.add_from_file(fd)
|
|
except FileNotFoundError as fnfe:
|
|
raise fnfe
|
|
except UnicodeError:
|
|
raise Exception(
|
|
"Incorrect encoding detected in {}, please "
|
|
"rebuild the dataset".format(f)
|
|
)
|
|
return
|
|
|
|
lines = f.readlines()
|
|
indices_start_line = self._load_meta(lines)
|
|
|
|
for line in lines[indices_start_line:]:
|
|
try:
|
|
line, field = line.rstrip().rsplit(" ", 1)
|
|
if field == "#fairseq:overwrite":
|
|
overwrite = True
|
|
line, field = line.rsplit(" ", 1)
|
|
else:
|
|
overwrite = False
|
|
count = int(field)
|
|
word = line
|
|
if word in self and not overwrite:
|
|
raise RuntimeError(
|
|
"Duplicate word found when loading Dictionary: '{}'. "
|
|
"Duplicate words can overwrite earlier ones by adding the "
|
|
"#fairseq:overwrite flag at the end of the corresponding row "
|
|
"in the dictionary file. If using the Camembert model, please "
|
|
"download an updated copy of the model file.".format(word)
|
|
)
|
|
self.add_symbol(word, n=count, overwrite=overwrite)
|
|
except ValueError:
|
|
raise ValueError(
|
|
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
|
|
)
|
|
|
|
def _save(self, f, kv_iterator):
|
|
if isinstance(f, str):
|
|
PathManager.mkdirs(os.path.dirname(f))
|
|
with PathManager.open(f, "w", encoding="utf-8") as fd:
|
|
return self.save(fd)
|
|
for k, v in kv_iterator:
|
|
print("{} {}".format(k, v), file=f)
|
|
|
|
def _get_meta(self):
|
|
return [], []
|
|
|
|
def _load_meta(self, lines):
|
|
return 0
|
|
|
|
def save(self, f):
|
|
"""Stores dictionary into a text file"""
|
|
ex_keys, ex_vals = self._get_meta()
|
|
self._save(
|
|
f,
|
|
zip(
|
|
ex_keys + self.symbols[self.nspecial :],
|
|
ex_vals + self.count[self.nspecial :],
|
|
),
|
|
)
|
|
|
|
def dummy_sentence(self, length):
|
|
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
|
|
t[-1] = self.eos()
|
|
return t
|
|
|
|
def encode_line(
|
|
self,
|
|
line,
|
|
line_tokenizer=tokenize_line,
|
|
add_if_not_exist=True,
|
|
consumer=None,
|
|
append_eos=True,
|
|
reverse_order=False,
|
|
) -> torch.IntTensor:
|
|
words = line_tokenizer(line)
|
|
if reverse_order:
|
|
words = list(reversed(words))
|
|
nwords = len(words)
|
|
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
|
|
|
|
for i, word in enumerate(words):
|
|
if add_if_not_exist:
|
|
idx = self.add_symbol(word)
|
|
else:
|
|
idx = self.index(word)
|
|
if consumer is not None:
|
|
consumer(word, idx)
|
|
ids[i] = idx
|
|
if append_eos:
|
|
ids[nwords] = self.eos_index
|
|
return ids
|
|
|
|
@staticmethod
|
|
def _add_file_to_dictionary_single_worker(
|
|
filename,
|
|
tokenize,
|
|
eos_word,
|
|
start_offset,
|
|
end_offset,
|
|
):
|
|
counter = Counter()
|
|
with Chunker(filename, start_offset, end_offset) as line_iterator:
|
|
for line in line_iterator:
|
|
for word in tokenize(line):
|
|
counter.update([word])
|
|
counter.update([eos_word])
|
|
return counter
|
|
|
|
@staticmethod
|
|
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
|
|
def merge_result(counter):
|
|
for w, c in sorted(counter.items()):
|
|
dict.add_symbol(w, c)
|
|
|
|
local_file = PathManager.get_local_path(filename)
|
|
offsets = find_offsets(local_file, num_workers)
|
|
if num_workers > 1:
|
|
chunks = zip(offsets, offsets[1:])
|
|
pool = Pool(processes=num_workers)
|
|
results = []
|
|
for (start_offset, end_offset) in chunks:
|
|
results.append(
|
|
pool.apply_async(
|
|
Dictionary._add_file_to_dictionary_single_worker,
|
|
(
|
|
local_file,
|
|
tokenize,
|
|
dict.eos_word,
|
|
start_offset,
|
|
end_offset,
|
|
),
|
|
)
|
|
)
|
|
pool.close()
|
|
pool.join()
|
|
for r in results:
|
|
merge_result(r.get())
|
|
else:
|
|
merge_result(
|
|
Dictionary._add_file_to_dictionary_single_worker(
|
|
local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
|
|
)
|
|
)
|
|
|
|
|
|
class TruncatedDictionary(object):
|
|
def __init__(self, wrapped_dict, length):
|
|
self.__class__ = type(
|
|
wrapped_dict.__class__.__name__,
|
|
(self.__class__, wrapped_dict.__class__),
|
|
{},
|
|
)
|
|
self.__dict__ = wrapped_dict.__dict__
|
|
self.wrapped_dict = wrapped_dict
|
|
self.length = min(len(self.wrapped_dict), length)
|
|
|
|
def __len__(self):
|
|
return self.length
|
|
|
|
def __getitem__(self, i):
|
|
if i < self.length:
|
|
return self.wrapped_dict[i]
|
|
return self.wrapped_dict.unk()
|