# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict

from fairseq import utils
from fairseq.data import (
    BacktranslationDataset,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
    NoisingDataset,
    RoundRobinZipDatasets,
    data_utils,
    indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator

from . import register_task
from .multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair


# ported from UnsupervisedMT
def parse_lambda_config(x):
    """
    Parse the configuration of the lambda coefficient (for scheduling).
    x = "3"                  # lambda will be a constant equal to x
    x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                             # to 0 during the first 1000 iterations
    x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                             # iterations, then will linearly increase to 1 until iteration 2000
    """
    split = x.split(",")
    if len(split) == 1:
        return float(x), None
    else:
        # each entry has the form "step:weight"; split on ":" explicitly
        # (os.pathsep would be ";" on Windows and break this format)
        split = [s.split(":") for s in split]
        assert all(len(s) == 2 for s in split)
        assert all(k.isdigit() for k, _ in split)
        assert all(
            int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
        )
        return float(split[0][1]), [(int(k), float(v)) for k, v in split]
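

# For illustration, the docstring examples parse as follows:
#   >>> parse_lambda_config("3")
#   (3.0, None)
#   >>> parse_lambda_config("0:1,1000:0")
#   (1.0, [(0, 1.0), (1000, 0.0)])
# i.e. the initial weight plus the (step, weight) breakpoints that
# lambda_step_func() in update_step() interpolates between.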


@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use a fixed weight during training if set to a floating point number. '
                                 'use a piecewise linear function over the number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use a fixed weight during training if set to a floating point number. '
                                 'use a piecewise linear function over the number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use a fixed weight during training if set to a floating point number. '
                                 'use a piecewise linear function over the number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # fmt: on
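
    # Example usage of the lambda options above: each --lambda-*-config flag
    # accepts either a constant weight, e.g. "--lambda-parallel-config 1.0",
    # or a schedule of "step:weight" pairs, e.g. a made-up
    # "--lambda-otf-bt-config 0:0,10000:1", which ramps the back-translation
    # loss weight linearly from 0 to 1 over the first 10000 updates and keeps
    # it at 1 afterwards (see parse_lambda_config and update_step below).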

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            denoising_lang_pairs = [
                "%s-%s" % (tgt, tgt)
                for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
            ]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        self.backtranslate_datasets = {}
        self.backtranslators = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
        if len(src_datasets) == 0:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )
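
        # Note on naming: the parallel split above is looked up as
        # "{split}.{src}-{tgt}.{lang}", e.g. "train.de-en.de" / "train.de-en.en"
        # (or under the reversed "en-de" prefix), while the monolingual data
        # used below is expected as "{split}.{lang}-None.{lang}".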

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
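        # The zipped dataset is keyed per language pair plus the auxiliary keys
        # above; e.g. with --lang-pairs de-en a training split can expose
        # "de-en" (parallel), "bt:de-en" (online back-translation) and
        # "denoising:de-en" (denoising autoencoding), and train_step() consumes
        # one batch per key.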

    def build_model(self, args, from_checkpoint=False):
        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                # back-translation for the pair src-tgt generates synthetic sources
                # from target-side monolingual data, so it uses the reverse model tgt-src
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    # default arguments bind this iteration's model, BOS token and
                    # generator, so each pair keeps its own back-translation function
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()

        if update_num > 0:
            self.update_step(update_num)

        from collections import defaultdict

        # aggregate with a defaultdict so unseen logging keys start at 0.0
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[logging_output_key] += logging_output[k]

        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample[lang_pair],
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample[sample_key],
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample[sample_key],
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        def lambda_step_func(config, n_iter):
            """
            Update a lambda value according to its schedule configuration.
            """
            ranges = [
                i
                for i in range(len(config) - 1)
                if config[i][0] <= n_iter < config[i + 1][0]
            ]
            if len(ranges) == 0:
                assert n_iter >= config[-1][0]
                return config[-1][1]
            assert len(ranges) == 1
            i = ranges[0]
            x_a, y_a = config[i]
            x_b, y_b = config[i + 1]
            return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
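
        # For illustration: with the schedule [(0, 1.0), (1000, 0.0)]
        # (i.e. "0:1,1000:0"), lambda_step_func returns 0.75 at update 250 and
        # 0.0 for every update >= 1000.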

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)