# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Pool

import torch

from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line

logger = logging.getLogger("binarizer")


@dataclass
class BinarizeSummary:
    """
    Keep track of what's going on in the binarizer.
    """

    num_seq: int = 0
    replaced: tp.Optional[Counter] = None
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base

        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary"):
        replaced = None
        if self.replaced is not None:
            replaced = self.replaced
        if other.replaced is not None:
            if replaced is None:
                replaced = other.replaced
            else:
                replaced += other.replaced
        self.replaced = replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok


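# For example (derived from __str__ above): str(BinarizeSummary(num_seq=2, num_tok=10))
# gives "2 sents, 10 tokens"; once `replaced` is populated, the percentage of
# replaced (unknown) tokens is appended to that string.

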
class Binarizer(ABC):
    """
    A binarizer describes how to take a string and build a tensor out of it.
    """

    @abstractmethod
    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ...


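# A minimal sketch of a custom Binarizer, for illustration only (not part of
# fairseq): it looks tokens up in a plain Python dict, mapping misses to an
# assumed `unk_id`. The names `vocab` and `unk_id` are hypothetical.
#
#     class DictBinarizer(Binarizer):
#         def __init__(self, vocab: tp.Dict[str, int], unk_id: int = 3) -> None:
#             self.vocab = vocab
#             self.unk_id = unk_id
#
#         def binarize_line(self, line: str, summary: BinarizeSummary) -> torch.IntTensor:
#             ids = [self.vocab.get(tok, self.unk_id) for tok in line.split()]
#             summary.num_seq += 1
#             summary.num_tok += len(ids)
#             return torch.IntTensor(ids)

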
def _worker_prefix(output_prefix: str, worker_id: int):
    return f"{output_prefix}.pt{worker_id}"
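    # e.g. _worker_prefix("out/train", 2) == "out/train.pt2": a scratch prefix
    # for worker 2's temporary .bin/.idx files, which are merged into the main
    # output and removed in multiprocess_dataset below.

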
class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor.
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of positions [pos1, pos2, pos3, pos4], but we
        # want pairs: [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the
        # chunks with start/end info. We zip the list with itself shifted by one
        # to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
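        # e.g. with offsets == [0, 100, 200, 300] (num_workers == 3):
        # first_chunk == (0, 100) and more_chunks == [(100, 200), (200, 300)];
        # the first chunk is handled in this process, the rest by the pool below.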
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(
                            output_prefix,
                            worker_id,
                        ),
                        dataset_impl,
                    ),
                    kwds={
                        "vocab_size": vocab_size,
                    }
                    if vocab_size is not None
                    else {},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]

            pool.close()
            pool.join()
            for r in worker_results:
                summ = r.get()
                final_summary.merge(summ)

        # do not close the bin file yet, as we need to merge the worker results into it
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary

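    # A hedged usage sketch (paths and settings are hypothetical, not from the
    # source): binarize a tokenized text file with a loaded vocabulary.
    #
    #     vocab = Dictionary.load("dict.src.txt")
    #     summary = FileBinarizer.multiprocess_dataset(
    #         input_file="train.src",
    #         dataset_impl="mmap",
    #         binarizer=VocabularyDatasetBinarizer(vocab),
    #         output_prefix="data-bin/train.src",
    #         vocab_size=len(vocab),
    #         num_workers=4,
    #     )
    #     logger.info(str(summary))
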
    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        Creates a dataset builder and appends binarized items to it. This function
        does not finalize the builder; this is useful if you want to do other things
        with your bin file, like appending/merging other files.
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary

    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        Same as above, but also finalizes the builder.
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ


class VocabularyDatasetBinarizer(Binarizer):
    """
    Takes a Dictionary/Vocabulary and assigns ids to each
    token using the dictionary's encode_line function.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        self.dict = dict
        self.tokenize = tokenize
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.already_numberized = already_numberized
        super().__init__()

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        if summary.replaced is None:
            summary.replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if self.already_numberized:
            id_strings = line.strip().split()
            id_list = [int(id_string) for id_string in id_strings]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)
        else:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids


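# A hedged sketch of binarize_line in isolation (the vocabulary contents are
# assumed, not from the source): known tokens map to their dictionary ids,
# unknown ones map to unk and are tallied in summary.replaced.
#
#     vocab = Dictionary()
#     for w in ["hello", "world"]:
#         vocab.add_symbol(w)
#     summary = BinarizeSummary()
#     ids = VocabularyDatasetBinarizer(vocab).binarize_line("hello world", summary)
#     # ids is a torch.IntTensor ending in vocab.eos(), since append_eos=True

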
class AlignmentDatasetBinarizer(Binarizer):
    """
    Binarize by parsing a set of alignments and packing
    them in a tensor (see utils.parse_alignment).
    """

    def __init__(
        self,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        ids = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids


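# A hedged sketch (alignment format assumed from fairseq's utils.parse_alignment):
# a line of src-tgt index pairs like "0-0 1-2 2-1" is parsed into an IntTensor.
#
#     from fairseq import utils
#     binarizer = AlignmentDatasetBinarizer(utils.parse_alignment)
#     ids = binarizer.binarize_line("0-0 1-2 2-1", BinarizeSummary())

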
class LegacyBinarizer:
    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, tp.Any]:
        binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

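    # A hedged usage sketch (hypothetical file name, not from the source):
    # collect every binarized line of a file through a consumer callback.
    #
    #     tensors = []
    #     stats = LegacyBinarizer.binarize("train.de", dico=vocab, consumer=tensors.append)
    #     # stats == {"nseq": ..., "nunk": ..., "ntok": ..., "replaced": Counter(...)}
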
    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, tp.Any]:
        binarizer = AlignmentDatasetBinarizer(alignment_parser)
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, tp.Any]:
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }