mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-11 22:40:03 +00:00
196 lines
7.2 KiB
Python
196 lines
7.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from dataclasses import dataclass, field
|
|
import torch
|
|
from fairseq import utils
|
|
from fairseq.data import LanguagePairDataset
|
|
from fairseq.dataclass import ChoiceEnum
|
|
from fairseq.tasks import register_task
|
|
from fairseq.tasks.translation import (
|
|
TranslationConfig,
|
|
TranslationTask,
|
|
load_langpair_dataset,
|
|
)
|
|
from fairseq.utils import new_arange
|
|
|
|
|
|
# Names of the target-corruption strategies accepted by the task's ``noise``
# option; ``TranslationLevenshteinTask.inject_noise`` dispatches on these
# exact strings.
NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"])
|
|
|
|
|
|
@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
    """Configuration for the ``translation_lev`` task.

    Extends ``TranslationConfig`` with one extra option that selects how the
    target tokens are corrupted before being handed to the model.
    """

    # Name of the corruption strategy; one of NOISE_CHOICES.
    noise: NOISE_CHOICES = field(
        default="random_delete",
        metadata={"help": "type of noise"},
    )
|
|
|
|
|
|
@register_task("translation_lev", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinTask(TranslationTask):
    """
    Translation (Sequence Generation) task for Levenshtein Transformer
    See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.

    On top of the parent translation task, every training/validation sample
    gets a ``"prev_target"`` entry holding a corrupted copy of ``"target"``
    (see :meth:`inject_noise`), and generation uses an
    ``IterativeRefinementGenerator``.
    """

    cfg: TranslationLevenshteinConfig

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The result is stored in ``self.datasets[split]``; nothing is returned.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number, used to pick one of the data
                paths in round-robin order
            combine (bool): forwarded to ``load_langpair_dataset``
            **kwargs: accepted for interface compatibility and ignored
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        # Cycle through the configured data paths, one per epoch.
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            prepend_bos=True,
        )

    def inject_noise(self, target_tokens):
        """Return a corrupted copy of ``target_tokens`` per ``self.cfg.noise``.

        Args:
            target_tokens: 2-D tensor of target token ids, one sequence per
                row (it is indexed along dimension 1 throughout).

        Returns:
            A tensor of token ids:

            * ``"random_delete"``: a random subset of the regular tokens is
              removed, survivors keep their order, rows are right-padded and
              the result is trimmed to the longest surviving row.
            * ``"random_mask"``: a random non-empty subset of the regular
              tokens is replaced by ``<unk>``.
            * ``"full_mask"``: every regular token is replaced by ``<unk>``.
            * ``"no_noise"``: ``target_tokens`` itself, unchanged.

            "Regular" means anything other than pad, bos and eos.

        Raises:
            NotImplementedError: if ``self.cfg.noise`` is none of the above.
        """

        def _random_delete(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()

            max_len = target_tokens.size(1)
            target_mask = target_tokens.eq(pad)
            # Random score per token; bos/eos are pinned to 0 (sorted first,
            # never cut) and pads to 1 (sorted last).
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(
                target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
            )
            target_score.masked_fill_(target_mask, 1)
            target_score, target_rank = target_score.sort(1)
            # Number of non-pad tokens per row, kept as a column vector.
            target_length = target_mask.size(1) - target_mask.float().sum(
                1, keepdim=True
            )

            # do not delete <bos> and <eos> (we assign 0 score for them)
            target_cutoff = (
                2
                + (
                    (target_length - 2)
                    * target_score.new_zeros(target_score.size(0), 1).uniform_()
                ).long()
            )
            # NOTE(review): the scores are already sorted here, so this second
            # sort presumably just yields each slot's position in sorted
            # order; slots at or beyond the cutoff are marked for deletion.
            target_cutoff = target_score.sort(1)[1] >= target_cutoff

            # Blank the deleted slots with pad, then undo the score ordering:
            # surviving tokens are sorted by their original position while
            # deleted slots (given index max_len) are pushed to the end.
            prev_target_tokens = (
                target_tokens.gather(1, target_rank)
                .masked_fill_(target_cutoff, pad)
                .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
            )
            # Trim trailing columns that are padding in every row.
            prev_target_tokens = prev_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]

            return prev_target_tokens

        def _random_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            # True where a token may be masked (not pad / bos / eos).
            target_masks = (
                target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
            )
            # Random score per token; special tokens get 2.0 so that they
            # sort after every maskable token.
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(~target_masks, 2.0)
            target_length = target_masks.sum(1).float()
            target_length = target_length * target_length.clone().uniform_()
            target_length = target_length + 1  # make sure to mask at least one token.

            _, target_rank = target_score.sort(1)
            target_cutoff = new_arange(target_rank) < target_length[:, None].long()
            # Translate "first k slots in score order" back to token positions.
            selected = target_cutoff.scatter(1, target_rank, target_cutoff)
            # The "+ 1" above forces one selection even for a row with no
            # maskable token, in which case the pick would land on pad/bos/eos.
            # Intersecting with the maskable positions keeps special tokens
            # intact and changes nothing for rows that have regular tokens.
            prev_target_tokens = target_tokens.masked_fill(
                selected & target_masks, unk
            )
            return prev_target_tokens

        def _full_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            # True on the special tokens, which are left untouched.
            target_mask = (
                target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
            )
            return target_tokens.masked_fill(~target_mask, unk)

        if self.cfg.noise == "random_delete":
            return _random_delete(target_tokens)
        elif self.cfg.noise == "random_mask":
            return _random_mask(target_tokens)
        elif self.cfg.noise == "full_mask":
            return _full_mask(target_tokens)
        elif self.cfg.noise == "no_noise":
            return target_tokens
        else:
            raise NotImplementedError

    def build_generator(self, models, args, **unused):
        """Build an ``IterativeRefinementGenerator`` over the target dictionary.

        Args:
            models: unused; present to match the API for SequenceGenerator.
            args: namespace-like object; the ``iter_decode_*``,
                ``decoding_format`` and ``retain_iter_history`` attributes are
                read with ``getattr`` and fall back to the defaults below
                when absent.
            **unused: ignored.
        """
        # add models input to match the API for SequenceGenerator
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

        return IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
            max_iter=getattr(args, "iter_decode_max_iter", 10),
            beam_size=getattr(args, "iter_decode_with_beam", 1),
            reranking=getattr(args, "iter_decode_with_external_reranker", False),
            decoding_format=getattr(args, "decoding_format", None),
            adaptive=not getattr(args, "iter_decode_force_max_iter", False),
            retain_history=getattr(args, "retain_iter_history", False),
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        """Wrap source tokens in a ``LanguagePairDataset`` for generation.

        The dataset is created with ``append_bos=True``.

        Raises:
            NotImplementedError: if ``constraints`` is not ``None``.
        """
        if constraints is not None:
            # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
            raise NotImplementedError(
                "Constrained decoding with the translation_lev task is not supported"
            )

        return LanguagePairDataset(
            src_tokens, src_lengths, self.source_dictionary, append_bos=True
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run one forward/backward pass on ``sample``.

        Puts the model in train mode, stores the noised target in
        ``sample["prev_target"]``, evaluates ``criterion`` and backpropagates
        through ``optimizer.backward``. ``update_num`` is not used here.

        Returns:
            The ``(loss, sample_size, logging_output)`` triple produced by
            ``criterion``; ``loss`` is multiplied by zero first when
            ``ignore_grad`` is set.
        """
        model.train()
        sample["prev_target"] = self.inject_noise(sample["target"])
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            # Still run backward, but on a zeroed loss.
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """Evaluate ``criterion`` on ``sample`` without tracking gradients.

        Puts the model in eval mode and, like ``train_step``, stores the
        noised target in ``sample["prev_target"]`` before the forward pass.

        Returns:
            The ``(loss, sample_size, logging_output)`` triple produced by
            ``criterion``.
        """
        model.eval()
        with torch.no_grad():
            sample["prev_target"] = self.inject_noise(sample["target"])
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output
|