mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-22 06:04:26 +00:00
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import logging
import os
from typing import Any, Dict, Iterator, List

import torch
from omegaconf import open_dict
from torch import nn

from fairseq import utils
from fairseq.data import encoders

logger = logging.getLogger(__name__)


def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == "checkpoint_file":
                    checkpoint_file = v
                elif (
                    k != "path"
                    # only set kwargs that don't already have overrides
                    and k not in kwargs
                ):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    if data_name_or_path.startswith("."):
        kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
    for file, arg in {
        "code": "bpe_codes",
        "bpecodes": "bpe_codes",
        "sentencepiece.bpe.model": "sentencepiece_model",
        "merges.txt": "bpe_merges",
        "vocab.json": "bpe_vocab",
    }.items():
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    model_path = [
        os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)
    ]

    if "is_vocoder" in kwargs:
        args = {"data": kwargs["data"], "model_path": model_path}
        task = None
        models = None
    else:
        models, args, task = checkpoint_utils.load_model_ensemble_and_task(
            model_path,
            arg_overrides=kwargs,
        )
    if "generation_args" in kwargs and kwargs["generation_args"]:
        for key in kwargs["generation_args"]:
            setattr(args["generation"], key, kwargs["generation_args"][key])

    return {
        "args": args,
        "task": task,
        "models": models,
    }


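# A minimal usage sketch for `from_pretrained` (hedged: the checkpoint
# directory "checkpoints/" and the file name "model.pt" are hypothetical
# placeholders, not paths shipped with fairseq):
#
#   from fairseq import hub_utils
#
#   x = hub_utils.from_pretrained(
#       "checkpoints/",                # local dir, URL, or key in an archive_map
#       checkpoint_file="model.pt",
#       data_name_or_path=".",         # resolve data relative to the archive
#   )
#   hub = hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

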
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """

    def __init__(self, cfg, task, models):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(cfg)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)

        self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
        self.bpe = encoders.build_bpe(cfg.bpe)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

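    # Usage sketch for the two methods above (hedged: `hub` is a loaded
    # GeneratorHubInterface and the sentence is illustrative). A plain string
    # input returns a single string; a list returns a list. Extra kwargs are
    # forwarded to `generate` as generation options:
    #
    #   hub.translate(["Hello world!"], beam=5)
    #   hub.sample("Hello world!", sampling=True, sampling_topk=10)
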
    def score(
        self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs
    ):
        if isinstance(sentences, str):
            return self.score(
                [sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs
            )[0]

        def encode(sentence):
            if replace_newline_with_eos:
                return torch.cat([self.encode(line) for line in sentence.splitlines()])
            else:
                return self.encode(sentence)

        # NOTE: this doesn't support translation tasks currently
        tokenized_sentences = [encode(sentence) for sentence in sentences]
        return [
            hypos[0]
            for hypos in self.generate(
                tokenized_sentences, score_reference=True, **kwargs
            )
        ]

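    # Usage sketch for `score` (hedged: `hub` is a hub interface for a
    # language model; the hypothesis keys shown are the ones this module's
    # own verbose logging reads):
    #
    #   hypo = hub.score("All work and no play ...")
    #   hypo["score"]              # model score for the sentence (log-prob based)
    #   hypo["positional_scores"]  # per-token log-probabilities
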
    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            return self.generate(
                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
            )[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(
            self.models,
            gen_args,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        )

        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for id, hypos in zip(batch["id"].tolist(), translations):
                results.append((id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                return getattr(gen_args, name, getattr(self.cfg, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    logger.info(
                        "P\t{}".format(
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    hypo["positional_scores"].tolist(),
                                )
                            )
                        )
                    )
                    if hypo["alignment"] is not None and getarg(
                        "print_alignment", False
                    ):
                        logger.info(
                            "A\t{}".format(
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in hypo["alignment"]
                                    ]
                                )
                            )
                        )
        return outputs

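    # Usage sketch for calling `generate` directly (hedged: illustrative
    # input; `hub` is a loaded GeneratorHubInterface). `encode`/`decode`
    # below are the tokenize -> BPE -> binarize pipeline and its inverse:
    #
    #   tokens = hub.encode("Hello world!")        # 1-D LongTensor
    #   hypos = hub.generate([tokens], beam=5)[0]  # hypotheses for input 0
    #   best = hub.decode(hypos[0]["tokens"])      # detokenized best hypothesis
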
    def encode(self, sentence: str) -> torch.LongTensor:
        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)
        return self.binarize(sentence)

    def decode(self, tokens: torch.LongTensor) -> str:
        sentence = self.string(tokens)
        sentence = self.remove_bpe(sentence)
        return self.detokenize(sentence)

    def tokenize(self, sentence: str) -> str:
        if self.tokenizer is not None:
            sentence = self.tokenizer.encode(sentence)
        return sentence

    def detokenize(self, sentence: str) -> str:
        if self.tokenizer is not None:
            sentence = self.tokenizer.decode(sentence)
        return sentence

    def apply_bpe(self, sentence: str) -> str:
        if self.bpe is not None:
            sentence = self.bpe.encode(sentence)
        return sentence

    def remove_bpe(self, sentence: str) -> str:
        if self.bpe is not None:
            sentence = self.bpe.decode(sentence)
        return sentence

    def binarize(self, sentence: str) -> torch.LongTensor:
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        lengths = torch.LongTensor([t.numel() for t in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator


class BPEHubInterface(object):
    """PyTorch Hub interface for Byte-Pair Encoding (BPE)."""

    def __init__(self, bpe, **kwargs):
        super().__init__()
        args = argparse.Namespace(bpe=bpe, **kwargs)
        self.bpe = encoders.build_bpe(args)
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        return self.bpe.encode(sentence)

    def decode(self, sentence: str) -> str:
        return self.bpe.decode(sentence)


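# Usage sketch for BPEHubInterface (hedged: "gpt2" is one of the BPE names
# registered in fairseq.data.encoders; "subword_nmt" and "sentencepiece" are
# other registered options and may need extra path kwargs):
#
#   bpe = BPEHubInterface("gpt2")
#   bpe.decode(bpe.encode("Hello world!"))  # round-trips to "Hello world!"

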
class TokenizerHubInterface(object):
    """PyTorch Hub interface for tokenization."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__()
        args = argparse.Namespace(tokenizer=tokenizer, **kwargs)
        self.tokenizer = encoders.build_tokenizer(args)
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        return self.tokenizer.encode(sentence)

    def decode(self, sentence: str) -> str:
        return self.tokenizer.decode(sentence)
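
# Usage sketch for TokenizerHubInterface (hedged: "moses" is one of the
# tokenizer names registered in fairseq.data.encoders and typically requires
# the sacremoses package; the source_lang kwarg is illustrative):
#
#   tok = TokenizerHubInterface("moses", source_lang="en")
#   tok.decode(tok.encode("Hello, world!"))  # detokenizes back to the input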