# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import logging
import time

import torch
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    ListDataset,
    data_utils,
    iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
    MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction


###
def get_time_gap(s, e):
    return str(
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    )
###
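
# Example (illustrative): get_time_gap(0.0, 3661.5) formats the elapsed time as
# "1:01:01.500000". Both timestamps pass through fromtimestamp() in the local
# timezone, so the offsets cancel for ordinary wall-clock measurements (barring
# a DST boundary between the two instants).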


logger = logging.getLogger(__name__)


@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of supported languages
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported
            languages to their dictionaries
        training (bool): whether the task should be configured for training or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

        The translation task provides the following additional command-line
        arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
                            action=FileContentsAction)
        parser.add_argument('--keep-inference-langtok', action='store_true',
                            help='keep language tokens in inference output (e.g. for analysis or debugging)')

        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
        # fmt: on
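
    # Illustrative invocation (a sketch, not a verified recipe: the data path,
    # architecture, and preprocessing are placeholders; --encoder-langtok /
    # --decoder-langtok come from MultilingualDatasetManager.add_args and
    # --sampling-method / --sampling-temperature from SamplingMethod):
    #
    #   fairseq-train data-bin/multilingual \
    #       --task translation_multi_simple_epoch \
    #       --lang-pairs en-de,en-fr \
    #       --encoder-langtok src --decoder-langtok \
    #       --sampling-method temperature --sampling-temperature 1.5 \
    #       --arch transformer --share-all-embeddings ...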

    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask-type subclasses to build
        # models for language pairs other than the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
        self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
        self.check_dicts(self.dicts, self.source_langs, self.target_langs)

        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method
        )

    def check_dicts(self, dicts, source_langs, target_langs):
        if self.args.source_dict is not None or self.args.target_dict is not None:
            # no need to check whether the source side and target side are sharing dictionaries
            return
        src_dict = dicts[source_langs[0]]
        tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )

    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(
            cls.load_dictionary, args, **kwargs
        )
        return cls(args, langs, dicts, training)

    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        shard_epoch = None
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always reloading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from the virtual data size and virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
        logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
        logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        if split in self.datasets:
            del self.datasets[split]
            logger.info("old dataset deleted manually")
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        self.datasets[split] = self.data_manager.load_dataset(
            split,
            self.training,
            epoch=epoch,
            combine=combine,
            shard_epoch=shard_epoch,
            **kwargs,
        )
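
    # Note on sharding (our reading of the code, not upstream documentation):
    # when the data manager serves sharded data with --virtual-epoch-size set,
    # `epoch` counts virtual epochs over a slice of the data, while
    # `shard_epoch` tracks full passes over the underlying shards, as estimated
    # by estimate_global_pass_epoch().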

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
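
    # Sketch of how this is typically driven (hypothetical input; the real
    # caller is fairseq-interactive):
    #
    #   src_tokens = [task.source_dictionary.encode_line(
    #       "hello world", add_if_not_exist=False)]
    #   src_lengths = [t.numel() for t in src_tokens]
    #   dataset = task.build_dataset_for_inference(src_tokens, src_lengths)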

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not getattr(args, "keep_inference_langtok", False):
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if tgt_langtok_spec:
                tgt_lang_tok = self.data_manager.get_decoder_langtok(
                    self.args.target_lang, tgt_langtok_spec
                )
                extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
                extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}

        # forward seq_gen_cls instead of hard-coding None, so callers can
        # supply a custom sequence generator class
        return super().build_generator(
            models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_model(self, args, from_checkpoint=False):
        return super().build_model(args, from_checkpoint)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    src_tokens = sample["net_input"]["src_tokens"]
                    bsz = src_tokens.size(0)
                    prefix_tokens = (
                        torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
                    )
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
            else:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    if tgt_langtok_spec
                    else self.target_dictionary.eos(),
                )
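
    # Illustration (token id assumed for clarity): with a batch of 3 source
    # sentences and a target langtok id of, say, 64002, the prefix becomes
    #
    #   prefix_tokens = tensor([[64002], [64002], [64002]])
    #
    # so every generated hypothesis is forced to start with the target-language
    # token, which build_generator() above strips from the output unless
    # --keep-inference-langtok is set.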

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        return self.data_manager.get_source_dictionary(self.source_langs[0])

    @property
    def target_dictionary(self):
        return self.data_manager.get_target_dictionary(self.target_langs[0])

    def create_batch_sampler_func(
        self,
        max_positions,
        ignore_invalid_inputs,
        max_tokens,
        max_sentences,
        required_batch_size_multiple=1,
        seed=1,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [
                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
            ]
            split = splits[0] if len(splits) > 0 else None
            # NEW implementation
            if epoch is not None:
                # initialize the dataset with the correct starting epoch
                dataset.set_epoch(epoch)

            # get indices ordered by example size
            start_time = time.time()
            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")

            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # filter examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs
                )
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # create mini-batches with given size constraints
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler

        return construct_batch_sampler
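
    # The closure returned above is handed to EpochBatchIterator as its
    # batch_sampler (see get_batch_iterator below), so batches are re-bucketed
    # from freshly ordered indices at the start of every epoch rather than
    # computed once and cached.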

    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            skip_remainder_batch (bool, optional): if set, discard the last,
                typically smaller, batch in each training epoch (default: False).
            grouped_shuffling (bool, optional): group batches, with each group
                containing num_shards batches, and shuffle the groups; reduces
                the difference between sequence lengths among workers for
                batches sorted by length.
            update_epoch_batch_itr (bool, optional): if true, do not use the
                cached batch iterator for the epoch.

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        # initialize the dataset with the correct starting epoch
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        if self.args.sampling_method == "RoundRobin":
            batch_iter = super().get_batch_iterator(
                dataset,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs,
                required_batch_size_multiple=required_batch_size_multiple,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=num_workers,
                epoch=epoch,
                data_buffer_size=data_buffer_size,
                disable_iterator_cache=disable_iterator_cache,
                skip_remainder_batch=skip_remainder_batch,
                update_epoch_batch_itr=update_epoch_batch_itr,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions,
            ignore_invalid_inputs,
            max_tokens,
            max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
        )

        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
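

# Illustrative generation invocation (a sketch; paths, checkpoint, and langtok
# flags are placeholders that must match the training configuration):
#
#   fairseq-generate data-bin/multilingual \
#       --task translation_multi_simple_epoch \
#       --path checkpoints/checkpoint_best.pt \
#       --lang-pairs en-de,en-fr -s en -t de \
#       --encoder-langtok src --decoder-langtok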