Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)

This commit is contained in:
Tony Ribeiro
2023-08-10 02:58:52 +02:00
parent 28024c5649
commit 60a8e5c9c6
465 changed files with 95671 additions and 0 deletions

View File

@@ -0,0 +1,136 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore
from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa
# register dataclass
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(cfg: FairseqDataclass, **kwargs):
task = None
task_name = getattr(cfg, "task", None)
if isinstance(task_name, str):
# legacy tasks
task = TASK_REGISTRY[task_name]
if task_name in TASK_DATACLASS_REGISTRY:
dc = TASK_DATACLASS_REGISTRY[task_name]
cfg = dc.from_namespace(cfg)
else:
task_name = getattr(cfg, "_name", None)
if task_name and task_name in TASK_DATACLASS_REGISTRY:
dc = TASK_DATACLASS_REGISTRY[task_name]
cfg = merge_with_parent(dc(), cfg)
task = TASK_REGISTRY[task_name]
assert (
task is not None
), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"
return task.setup_task(cfg, **kwargs)
def register_task(name, dataclass=None):
"""
New tasks can be added to fairseq with the
:func:`~fairseq.tasks.register_task` function decorator.
For example::
@register_task('classification')
class ClassificationTask(FairseqTask):
(...)
.. note::
All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
interface.
Args:
name (str): the name of the task
"""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError("Cannot register duplicate task ({})".format(name))
if not issubclass(cls, FairseqTask):
raise ValueError(
"Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
)
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError(
"Cannot register task with duplicate class name ({})".format(
cls.__name__
)
)
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
raise ValueError(
"Dataclass {} must extend FairseqDataclass".format(dataclass)
)
cls.__dataclass = dataclass
if dataclass is not None:
TASK_DATACLASS_REGISTRY[name] = dataclass
cs = ConfigStore.instance()
node = dataclass()
node._name = name
cs.store(name=name, group="task", node=node, provider="fairseq")
return cls
return register_task_cls
def get_task(name):
return TASK_REGISTRY[name]
def import_tasks(tasks_dir, namespace):
for file in os.listdir(tasks_dir):
path = os.path.join(tasks_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
task_name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module(namespace + "." + task_name)
# expose `task_parser` for sphinx
if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group("Task name")
# fmt: off
group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``')
# fmt: on
group_args = parser.add_argument_group(
"Additional command-line arguments"
)
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser
# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
import_tasks(tasks_dir, "fairseq.tasks")

View File

@@ -0,0 +1,343 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import torch
import json
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any
from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
from . import register_task
from .. import utils
from ..logging import metrics
logger = logging.getLogger(__name__)
class LabelEncoder(object):
def __init__(self, dictionary):
self.dictionary = dictionary
def __call__(self, label):
return self.dictionary.encode_line(
label, append_eos=False, add_if_not_exist=False
)
def label_len_fn(label):
return len(label.split(" "))
@dataclass
class AudioFinetuningConfig(AudioPretrainingConfig):
# Options for reporting WER metrics during validation. Only applicable to
# Seq2Seq models during fine-tuning
eval_wer: bool = field(
default=False, metadata={"help": "compute WER for Seq2Seq models"}
)
eval_wer_config: GenerationConfig = field(
default_factory=lambda: GenerationConfig(),
metadata={"help": "beam search config for evaluating wer during training"},
)
eval_wer_tokenizer: Any = field(
default=None,
metadata={"help": "tokenizer config for evaluating wer during training"},
)
eval_wer_post_process: str = field(
default="letter",
metadata={
"help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
},
)
eval_bleu: bool = field(
default=False, metadata={"help": "evaluation with BLEU scores"}
)
eval_bleu_detok: Optional[str] = field(
default=None,
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); "
"required if using --eval-bleu; use 'space' to disable "
"detokenization; see fairseq.data.encoders for other options"
},
)
eval_bleu_detok_args: str = field(
default="{}", metadata={"help": "args for building the tokenizer, if needed"}
)
eval_tokenized_bleu: bool = field(
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
)
eval_bleu_remove_bpe: Optional[str] = field(
default=None, metadata={"help": "remove BPE before computing BLEU"}
)
eval_bleu_args: str = field(
default="{}",
metadata={
"help": "generation args for BLUE scoring, e.g., "
'\'{"beam": 4, "lenpen": 0.6}\''
},
)
eval_bleu_print_samples: bool = field(
default=False, metadata={"help": "print sample generations during validation"}
)
autoregressive: bool = field(
default=False,
metadata={
"help": "required for autoregressive decoders (like seq2seq models); "
"adds 'prev_output_tokens' to input and appends eos to target"
},
)
@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
class AudioFinetuningTask(AudioPretrainingTask):
""" """
cfg: AudioFinetuningConfig
def __init__(
self,
cfg: AudioFinetuningConfig,
):
super().__init__(cfg)
self.blank_symbol = "<s>"
self.state.add_factory("target_dictionary", self.load_target_dictionary)
def load_target_dictionary(self):
if self.cfg.labels:
dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
return Dictionary.load(dict_path)
return None
def load_dataset(
self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
assert task_cfg.labels is not None
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level,
)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.state.target_dictionary
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.cfg.eval_wer and self.cfg.autoregressive:
metrics = self._inference_with_wer(self.sequence_generator, sample, model)
logging_output["_num_char_errors"] = metrics["num_char_errors"]
logging_output["_num_chars"] = metrics["num_chars"]
logging_output["_num_word_errors"] = metrics["num_word_errors"]
logging_output["_num_words"] = metrics["num_words"]
if self.cfg.eval_bleu and self.cfg.autoregressive:
metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output["_bleu_sys_len"] = metrics.sys_len
logging_output["_bleu_ref_len"] = metrics.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(metrics.counts) == 4
for i in range(4):
logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
return loss, sample_size, logging_output
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(model_cfg, from_checkpoint)
if self.cfg.eval_wer and self.cfg.autoregressive:
self.sequence_generator = self.build_generator(
[model],
self.cfg.eval_wer_config,
)
if self.cfg.eval_wer_tokenizer:
self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
else:
self.tokenizer = None
if self.cfg.eval_bleu and self.cfg.autoregressive:
assert self.cfg.eval_bleu_detok is not None, (
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(self.cfg.eval_bleu_args)
gen_args = Namespace(**gen_args)
self.sequence_generator = self.build_generator([model], gen_args)
return model
def _inference_with_wer(self, generator, sample, model):
import editdistance
def decode(toks):
s = self.target_dictionary.string(
toks.int().cpu(),
self.cfg.eval_wer_post_process,
escape_unk=True,
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
num_word_errors, num_char_errors = 0, 0
num_chars, num_words = 0, 0
gen_out = self.inference_step(generator, [model], sample, None)
for i in range(len(gen_out)):
hyp = decode(gen_out[i][0]["tokens"])
ref = decode(
utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
)
num_char_errors += editdistance.eval(hyp, ref)
num_chars += len(ref)
hyp_words = hyp.split()
ref_words = ref.split()
num_word_errors += editdistance.eval(hyp_words, ref_words)
num_words += len(ref_words)
return {
"num_char_errors": num_char_errors,
"num_chars": num_chars,
"num_word_errors": num_word_errors,
"num_words": num_words,
}
def _inference_with_bleu(self, generator, sample, model):
import sacrebleu
def decode(toks, is_ref):
s = self.target_dictionary.string(
toks.int().cpu(),
self.cfg.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
refs.append(
decode(
utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
is_ref=True, # don't count <unk> as matches to the hypo
)
)
if self.cfg.eval_bleu_print_samples:
logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
logger.info("T-{} {}".format(sample["id"][0], refs[0]))
eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.cfg.eval_wer:
zero = torch.scalar_tensor(0.0)
num_char_errors = sum(
log.get("_num_char_errors", zero) for log in logging_outputs
)
num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
num_word_errors = sum(
log.get("_num_word_errors", zero) for log in logging_outputs
)
num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
metrics.log_scalar("_num_char_errors", num_char_errors)
metrics.log_scalar("_num_chars", num_chars)
metrics.log_scalar("_num_word_errors", num_word_errors)
metrics.log_scalar("_num_words", num_words)
if num_chars > 0:
metrics.log_derived(
"uer",
lambda meters: meters["_num_char_errors"].sum
* 100.0
/ meters["_num_chars"].sum
if meters["_num_chars"].sum > 0
else float("nan"),
)
if num_words > 0:
metrics.log_derived(
"wer",
lambda meters: meters["_num_word_errors"].sum
* 100.0
/ meters["_num_words"].sum
if meters["_num_words"].sum > 0
else float("nan"),
)
if self.cfg.eval_bleu:
len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
count_keys = [f"_bleu_counts_{i}" for i in range(4)]
total_keys = [f"_bleu_totals_{i}" for i in range(4)]
for k in len_keys + count_keys + total_keys:
metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))
import sacrebleu
metrics.log_derived(
"bleu",
lambda meters: sacrebleu.compute_bleu(
correct=[meters[k].sum for k in count_keys],
total=[meters[k].sum for k in total_keys],
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
smooth_method="exp",
).score,
)

View File

@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING, II, OmegaConf
from fairseq.data import BinarizedAudioDataset, FileAudioDataset
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.text_compressor import TextCompressionLevel
from . import FairseqTask, register_task
logger = logging.getLogger(__name__)
@dataclass
class InferredW2vConfig:
# The following are needed to precompute mask and mask channel indices
# before model's forward.
mask_length: Optional[int] = II("model.mask_length")
mask_prob: Optional[float] = II("model.mask_prob")
mask_selection: Optional[str] = II("model.mask_selection")
mask_other: Optional[float] = II("model.mask_other")
no_mask_overlap: Optional[bool] = II("model.no_mask_overlap")
mask_min_space: Optional[int] = II("model.mask_min_space")
mask_channel_length: Optional[int] = II("model.mask_channel_length")
mask_channel_prob: Optional[float] = II("model.mask_channel_prob")
mask_channel_selection: Optional[str] = II("model.mask_channel_selection")
mask_channel_other: Optional[float] = II("model.mask_channel_other")
no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap")
mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space")
conv_feature_layers: Optional[str] = II("model.conv_feature_layers")
encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim")
@dataclass
class AudioPretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
labels: Optional[str] = field(
default=None,
metadata={"help": "extension of the label file to load, used for fine-tuning"},
)
binarized_dataset: bool = field(
default=False,
metadata={
"help": "if true, loads binarized dataset (useful for very large datasets). "
"See examples/wav2vec/scripts/binarize_manifest.sh"
},
)
sample_rate: int = field(
default=16_000,
metadata={
"help": "target sample rate. audio files will be up/down sampled to this rate"
},
)
normalize: bool = field(
default=False,
metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
)
enable_padding: bool = field(
default=False, metadata={"help": "pad shorter samples instead of cropping"}
)
max_sample_size: Optional[int] = field(
default=None, metadata={"help": "max sample size to crop to for batching"}
)
min_sample_size: Optional[int] = field(
default=None, metadata={"help": "min sample size to skip small examples"}
)
num_batch_buckets: int = field(
default=0,
metadata={"help": "number of buckets"},
)
precompute_mask_indices: bool = field(
default=False,
metadata={
"help": "flag to compute mask indices in data preparation.",
},
)
inferred_w2v_config: Optional[InferredW2vConfig] = field(
default=None,
metadata={
"help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)",
},
)
tpu: bool = II("common.tpu")
text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field(
default="none",
metadata={
"help": "compression level for texts (e.g. audio filenames, "
"target texts): none/low/high (default: none). "
},
)
@register_task("audio_pretraining", dataclass=AudioPretrainingConfig)
class AudioPretrainingTask(FairseqTask):
""" """
cfg: AudioPretrainingConfig
@classmethod
def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (AudioPretrainingConfig): configuration of this task
"""
return cls(cfg)
def _get_mask_precompute_kwargs(self, cfg):
if self.cfg.precompute_mask_indices or self.cfg.tpu:
assert (
cfg.inferred_w2v_config is not None
), "inferred_w2v_config must be set"
return OmegaConf.to_container(
cfg.inferred_w2v_config, resolve=True, enum_to_str=True
)
else:
return {}
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
task_cfg = task_cfg or self.cfg
# upgrade old task
if isinstance(task_cfg, Namespace):
if not hasattr(task_cfg, "autoregressive"):
task_cfg.autoregressive = not task_cfg.criterion == "ctc"
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
if getattr(task_cfg, "binarized_dataset", False):
self.datasets[split] = BinarizedAudioDataset(
data_path,
split=split,
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
max_sample_size=self.cfg.max_sample_size,
min_sample_size=self.cfg.min_sample_size,
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
**self._get_mask_precompute_kwargs(task_cfg),
)
else:
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
self.datasets[split] = FileAudioDataset(
manifest_path=manifest_path,
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
max_sample_size=self.cfg.max_sample_size,
min_sample_size=self.cfg.min_sample_size,
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
text_compression_level=text_compression_level,
**self._get_mask_precompute_kwargs(task_cfg),
)
if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0:
logger.info(
"Pretraining on TPUs may suffer convergence "
"issues when training with `mask_channel_prob` value of "
"0. You may want to set this to a low value close to 0."
)
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(model_cfg, from_checkpoint)
actualized_cfg = getattr(model, "cfg", None)
if actualized_cfg is not None:
# if "w2v_args" in actualized_cfg:
if hasattr(actualized_cfg, "w2v_args"):
model_cfg.w2v_args = actualized_cfg.w2v_args
return model

View File

@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
from collections import OrderedDict
import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
"""
Task for training cross-lingual language models.
For more details look at: https://arxiv.org/pdf/1901.07291.pdf
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data",
help="colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner",
)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments" " per sample",
)
parser.add_argument(
"--monolingual-langs",
default="en",
type=str,
help="comma separated list of languages for which we"
" want to train XLM on",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle each monolingual dataset while" " training",
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
self.distributed_world_size = args.distributed_world_size
self.langs2id = self._lang_to_id(args.monolingual_langs)
def _lang_to_id(self, languages: str):
"""
Build a map from languages to ids. These ids are used as segment labels
for cross-lingual LM training.
"""
lang2id = {}
langs = [l.strip() for l in languages.split(",")]
for id, lang in enumerate(langs):
lang2id[lang] = id
return lang2id
@classmethod
def load_dictionary(cls, filename):
return MaskedLMDictionary.load(filename)
@classmethod
def build_dictionary(
cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
):
d = MaskedLMDictionary()
for filename in filenames:
Dictionary.add_file_to_dictionary(
filename, d, tokenizer.tokenize_line, workers
)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@property
def target_dictionary(self):
return self.dictionary
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def _load_single_lang_dataset(self, split, epoch):
loaded_datasets = []
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
path = os.path.join(data_path, split_k)
ds = data_utils.load_indexed_dataset(
path, self.dictionary, self.args.dataset_impl
)
if ds is None:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
# Since we append each block with the classification_token,
# we need to effectively create blocks of length
# tokens_per_sample-1
loaded_datasets.append(
TokenBlockDataset(
ds,
ds.sizes,
self.args.tokens_per_sample - 1,
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
)
)
logger.info(
"{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
)
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
return dataset, sizes
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset_map = OrderedDict()
for lang in self.langs2id.keys():
# Datasets are expected to be in "split.lang" format (Eg: train.en)
language_split = "{}.{}".format(split, lang)
block_dataset, sizes = self._load_single_lang_dataset(
split=language_split, epoch=epoch
)
dataset_map[lang] = MaskedLMDataset(
dataset=block_dataset,
sizes=sizes,
vocab=self.dictionary,
pad_idx=self.dictionary.pad(),
mask_idx=self.dictionary.mask(),
classif_token_idx=self.dictionary.eos(),
sep_token_idx=self.dictionary.eos(),
shuffle=getattr(self.args, "shuffle", False),
has_pairs=False,
segment_id=self.langs2id[lang],
seed=self.seed,
)
self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
logger.info(
"{} {} {} examples".format(
utils.split_paths(self.args.data)[epoch - 1],
split,
len(self.datasets[split]),
)
)

View File

@@ -0,0 +1,296 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Optional
import numpy as np
from omegaconf import II, MISSING
from fairseq import utils
from fairseq.data import (
AppendTokenDataset,
DenoisingDataset,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumelDataset,
PadDataset,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from ..data.indexed_dataset import get_available_dataset_impl
logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"])
@dataclass
class DenoisingConfig(FairseqDataclass):
data: str = field(
default=MISSING,
metadata={"help": "path to data directory"},
)
bpe: Optional[str] = field(
default=None,
metadata={"help": "TODO"},
)
tokens_per_sample: int = field(
default=512,
metadata={
"help": "max number of total tokens over all segments "
"per sample for dataset"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="complete_doc",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
replace_length: int = field(
default=0,
metadata={"help": "TODO, should only allow -1, 0 and 1"},
)
mask: float = field(
default=0.0,
metadata={"help": "fraction of words/subwords that will be masked"},
)
mask_random: float = field(
default=0.0,
metadata={"help": "instead of using [MASK], use random token this often"},
)
insert: float = field(
default=0.0,
metadata={"help": "insert this percentage of additional random tokens"},
)
permute: float = field(
default=0.0,
metadata={"help": "take this proportion of subwords and permute them"},
)
rotate: float = field(
default=0.5,
metadata={"help": "rotate this proportion of inputs"},
)
poisson_lambda: float = field(
default=3.0,
metadata={"help": "randomly shuffle sentences for this proportion of inputs"},
)
shuffle_instance: float = field(
default=0.0,
metadata={"help": "shuffle this proportion of sentences in all inputs"},
)
mask_length: MASK_LENGTH_CHOICES = field(
default="subword",
metadata={"help": "mask length to choose"},
)
permute_sentences: int = field(
default=-1,
metadata={
"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"
},
)
seed: int = II("common.seed")
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
max_source_positions: int = field(
default=1024,
metadata={"help": "max number of tokens in the source sequence"},
)
max_target_positions: int = field(
default=1024,
metadata={"help": "max number of tokens in the target sequence"},
)
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
@register_task("denoising", dataclass=DenoisingConfig)
class DenoisingTask(FairseqTask):
"""
Denoising task for applying sequence to sequence denoising. (ie. BART)
"""
cfg: DenoisingConfig
def __init__(self, cfg, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
@classmethod
def setup_task(cls, cfg: DenoisingConfig, **kwargs):
"""Setup the task."""
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(cfg, "shuffle_instance"):
cfg.shuffle_instance = False
return cls(cfg, dictionary)
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.dictionary,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
dataset = StripTokenDataset(dataset, self.dictionary.eos())
dataset = maybe_shorten_dataset(
dataset,
split,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.cfg.tokens_per_sample,
self.cfg.seed,
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.cfg.tokens_per_sample - 2,
# one less for <s> and one for </s>
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
return dataset
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset = self._load_dataset_split(split, epoch, combine)
mask_whole_words = (
get_whole_word_mask(self.cfg.bpe, self.source_dictionary)
if self.cfg.mask_length != "subword"
else None
)
self.datasets[split] = DenoisingDataset(
dataset,
dataset.sizes,
self.dictionary,
self.mask_idx,
mask_whole_words,
shuffle=self.cfg.shuffle_instance,
seed=self.cfg.seed,
mask=self.cfg.mask,
mask_random=self.cfg.mask_random,
insert=self.cfg.insert,
rotate=self.cfg.rotate,
permute_sentences=self.cfg.permute_sentences,
bpe=self.cfg.bpe,
replace_length=self.cfg.replace_length,
mask_length=self.cfg.mask_length,
poisson_lambda=self.cfg.poisson_lambda,
)
logger.info(
"Split: {0}, Loaded {1} samples of denoising_dataset".format(
split,
len(self.datasets[split]),
)
)
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
"""
Generate batches for inference. We assume that the input begins with a
bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
"""
pad = self.source_dictionary.pad()
eos = self.source_dictionary.eos()
src_dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=self.cfg.tokens_per_sample - 2, # for <s> and </s>
pad=pad,
eos=eos,
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
prev_output_tokens = PrependTokenDataset(
StripTokenDataset(src_dataset, eos), eos
)
src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
return NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": src_dataset,
"src_lengths": NumelDataset(src_dataset, reduce=False),
"prev_output_tokens": PadDataset(
prev_output_tokens, pad_idx=pad, left_pad=False
),
},
"target": src_dataset,
},
sizes=[np.array(src_lengths)],
)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.dictionary

View File

@@ -0,0 +1,693 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Callable, Dict, List
import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.amp_optimizer import AMPOptimizer
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
class StatefulContainer(object):
def __init__(self):
self._state = dict()
self._factories = dict()
def add_factory(self, name, factory: Callable[[], Any]):
self._factories[name] = factory
def merge_state_dict(self, state_dict: Dict[str, Any]):
self._state.update(state_dict)
@property
def state_dict(self) -> Dict[str, Any]:
return self._state
def __getattr__(self, name):
if name not in self._state and name in self._factories:
self._state[name] = self._factories[name]()
if name in self._state:
return self._state[name]
raise AttributeError(f"Task state has no factory for attribute {name}")
class FairseqTask(object):
"""
Tasks store dictionaries and provide helpers for loading/iterating over
Datasets, initializing the Model/Criterion and calculating the loss.
Tasks have limited statefulness. In particular, state that needs to be
saved to/loaded from checkpoints needs to be stored in the `self.state`
:class:`StatefulContainer` object. For example::
self.state.add_factory("dictionary", self.load_dictionary)
print(self.state.dictionary) # calls self.load_dictionary()
This is necessary so that when loading checkpoints, we can properly
recreate the task state after initializing the task instance.
"""
@classmethod
def add_args(cls, parser):
"""Add task-specific arguments to the parser."""
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@staticmethod
def logging_outputs_can_be_summed(criterion) -> bool:
"""
Whether the logging outputs returned by `train_step` and `valid_step` can
be summed across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
"""
return criterion.logging_outputs_can_be_summed()
def __init__(self, cfg: FairseqDataclass, **kwargs):
self.cfg = cfg
self.datasets = dict()
self.dataset_to_epoch_iter = dict()
self.state = StatefulContainer()
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
return Dictionary.load(filename)
@classmethod
def build_dictionary(
cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
):
"""Build the dictionary
Args:
filenames (list): list of filenames
workers (int): number of concurrent workers
threshold (int): defines the minimum word count
nwords (int): defines the total number of words in the final dictionary,
including special symbols
padding_factor (int): can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
d = Dictionary()
for filename in filenames:
Dictionary.add_file_to_dictionary(
filename, d, tokenizer.tokenize_line, workers
)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@classmethod
def setup_task(cls, cfg: DictConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (omegaconf.DictConfig): parsed command-line arguments
"""
return cls(cfg, **kwargs)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.cfg, "data", "")
def load_dataset(
self,
split: str,
combine: bool = False,
task_cfg: FairseqDataclass = None,
**kwargs,
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
combine (bool): combines a split segmented into pieces into one dataset
task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used
to load datasets
"""
raise NotImplementedError
def dataset(self, split):
"""
Return a loaded dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
Returns:
a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
"""
from fairseq.data import FairseqDataset
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
if not isinstance(self.datasets[split], FairseqDataset):
raise TypeError("Datasets are expected to be of type FairseqDataset")
return self.datasets[split]
def filter_indices_by_size(
self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
):
"""
Filter examples that are too large
Args:
indices (np.array): original array of sample indices
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
Returns:
np.array: array of filtered sample indices
"""
indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
if len(ignored) > 0:
if not ignore_invalid_inputs:
raise Exception(
(
"Size of sample #{} is invalid (={}) since max_positions={}, "
"skip this example with --skip-invalid-size-inputs-valid-test"
).format(ignored[0], dataset.size(ignored[0]), max_positions)
)
logger.warning(
(
"{:,} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).format(len(ignored), max_positions, ignored[:10])
)
return indices
def can_reuse_epoch_itr(self, dataset):
# We can reuse the epoch iterator across epochs as long as the dataset
# hasn't disabled it. We default to ``False`` here, although in practice
# this will be ``True`` for most datasets that inherit from
# ``FairseqDataset`` due to the base implementation there.
return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False)
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 1).
data_buffer_size (int, optional): number of batches to
preload (default: 0).
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
skip_remainder_batch (bool, optional): if set, discard the last
batch in each training epoch, as the last batch is often smaller than
local_batch_size * distributed_word_size (default: ``True``).
grouped_shuffling (bool, optional): group batches with each groups
containing num_shards batches and shuffle groups. Reduces difference
between sequence lengths among workers for batches sorted by length.
update_epoch_batch_itr (bool optional): if true then donot use the cached
batch iterator for the epoch
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
can_reuse_epoch_itr = (
not disable_iterator_cache
and not update_epoch_batch_itr
and self.can_reuse_epoch_itr(dataset)
)
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
return self.dataset_to_epoch_iter[dataset]
assert isinstance(dataset, FairseqDataset)
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
# filter examples that are too large
if max_positions is not None:
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
# create mini-batches with given size constraints
batch_sampler = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
persistent_workers = getattr(self.cfg, "persistent_workers", False)
# return a reusable, sharded iterator
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
skip_remainder_batch=skip_remainder_batch,
grouped_shuffling=grouped_shuffling,
reuse_dataloader=reuse_dataloader,
persistent_workers=persistent_workers,
)
if can_reuse_epoch_itr:
self.dataset_to_epoch_iter[dataset] = epoch_iter
return epoch_iter
def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
"""
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task.
Args:
cfg (FairseqDataclass): configuration object
Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance
"""
from fairseq import models, quantization_utils
model = models.build_model(cfg, self, from_checkpoint)
model = quantization_utils.quantize_model_scalar(model, cfg)
return model
def build_criterion(self, cfg: DictConfig):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
Args:
cfg (omegaconf.DictConfig): configration object
Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance
"""
from fairseq import criterions
return criterions.build_criterion(cfg, self)
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
prefix_allowed_tokens_fn=None,
):
"""
Build a :class:`~fairseq.SequenceGenerator` instance for this
task.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
args (fairseq.dataclass.configs.GenerationConfig):
configuration object (dataclass) for generation
extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
through to SequenceGenerator
prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
If provided, this function constrains the beam search to
allowed tokens only at each step. The provided function
should take 2 arguments: the batch ID (`batch_id: int`)
and a unidimensional tensor of token ids (`inputs_ids:
torch.Tensor`). It has to return a `List[int]` with the
allowed tokens for the next generation step conditioned
on the previously generated tokens (`inputs_ids`) and
the batch ID (`batch_id`). This argument is useful for
constrained generation conditioned on the prefix, as
described in "Autoregressive Entity Retrieval"
(https://arxiv.org/abs/2010.00904) and
https://github.com/facebookresearch/GENRE.
"""
if getattr(args, "score_reference", False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(
self.target_dictionary,
compute_alignment=getattr(args, "print_alignment", False),
)
from fairseq.sequence_generator import (
SequenceGenerator,
SequenceGeneratorWithAlignment,
)
# Choose search strategy. Defaults to Beam Search.
sampling = getattr(args, "sampling", False)
sampling_topk = getattr(args, "sampling_topk", -1)
sampling_topp = getattr(args, "sampling_topp", -1.0)
diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
match_source_len = getattr(args, "match_source_len", False)
diversity_rate = getattr(args, "diversity_rate", -1)
constrained = getattr(args, "constraints", False)
if prefix_allowed_tokens_fn is None:
prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
if (
sum(
int(cond)
for cond in [
sampling,
diverse_beam_groups > 0,
match_source_len,
diversity_rate > 0,
]
)
> 1
):
raise ValueError("Provided Search parameters are mutually exclusive.")
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
if sampling:
search_strategy = search.Sampling(
self.target_dictionary, sampling_topk, sampling_topp
)
elif diverse_beam_groups > 0:
search_strategy = search.DiverseBeamSearch(
self.target_dictionary, diverse_beam_groups, diverse_beam_strength
)
elif match_source_len:
# this is useful for tagging applications where the output
# length should match the input length, so we hardcode the
# length constraints for simplicity
search_strategy = search.LengthConstrainedBeamSearch(
self.target_dictionary,
min_len_a=1,
min_len_b=0,
max_len_a=1,
max_len_b=0,
)
elif diversity_rate > -1:
search_strategy = search.DiverseSiblingsSearch(
self.target_dictionary, diversity_rate
)
elif constrained:
search_strategy = search.LexicallyConstrainedBeamSearch(
self.target_dictionary, args.constraints
)
elif prefix_allowed_tokens_fn:
search_strategy = search.PrefixConstrainedBeamSearch(
self.target_dictionary, prefix_allowed_tokens_fn
)
else:
search_strategy = search.BeamSearch(self.target_dictionary)
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
if seq_gen_cls is None:
if getattr(args, "print_alignment", False):
seq_gen_cls = SequenceGeneratorWithAlignment
extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
else:
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
models,
self.target_dictionary,
beam_size=getattr(args, "beam", 5),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
search_strategy=search_strategy,
**extra_gen_cls_kwargs,
)
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
"""
Do forward and backward, and return the loss as computed by *criterion*
for the given *model* and *sample*.
Args:
sample (dict): the mini-batch. The format is defined by the
:class:`~fairseq.data.FairseqDataset`.
model (~fairseq.models.BaseFairseqModel): the model
criterion (~fairseq.criterions.FairseqCriterion): the criterion
optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
update_num (int): the current update
ignore_grad (bool): multiply loss by 0 if this is set to True
Returns:
tuple:
- the loss
- the sample size, which is used as the denominator for the
gradient
- logging outputs to display while training
"""
model.train()
model.set_num_updates(update_num)
with torch.autograd.profiler.record_function("forward"):
with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output
def optimizer_step(self, optimizer, model, update_num):
optimizer.step()
def build_dataset_for_inference(
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
) -> torch.utils.data.Dataset:
raise NotImplementedError
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, constraints=constraints
)
def begin_epoch(self, epoch, model):
"""Hook function called before the start of each epoch."""
pass
def begin_valid_epoch(self, epoch, model):
"""Hook function called before the start of each validation epoch."""
pass
def aggregate_logging_outputs(self, logging_outputs, criterion):
"""[deprecated] Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
with metrics.aggregate() as agg:
self.reduce_metrics(logging_outputs, criterion)
return agg.get_smoothed_values()
def reduce_metrics(self, logging_outputs, criterion):
"""Aggregate logging outputs from data parallel training."""
# backward compatibility for tasks that override aggregate_logging_outputs
base_func = FairseqTask.aggregate_logging_outputs
self_func = getattr(self, "aggregate_logging_outputs").__func__
if self_func is not base_func:
utils.deprecation_warning(
"Tasks should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = self.aggregate_logging_outputs(
logging_outputs, criterion
)
for k, v in agg_logging_outputs.items():
metrics.log_scalar(k, v)
return
if not any("ntokens" in log for log in logging_outputs):
warnings.warn(
"ntokens not found in Criterion logging outputs, cannot log wpb or wps"
)
else:
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
metrics.log_scalar("wpb", ntokens, priority=180, round=1)
metrics.log_speed("wps", ntokens, priority=90, round=1)
if not any("nsentences" in log for log in logging_outputs):
warnings.warn(
"nsentences not found in Criterion logging outputs, cannot log bsz"
)
else:
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("bsz", nsentences, priority=190, round=1)
criterion.__class__.reduce_metrics(logging_outputs)
def state_dict(self):
if self.state is not None:
return self.state.state_dict
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
if self.state is not None:
self.state.merge_state_dict(state_dict)
def max_positions(self):
"""Return the max input length allowed by the task."""
return None
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
raise NotImplementedError
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
raise NotImplementedError
def build_tokenizer(self, args):
"""Build the pre-tokenizer for this task."""
return encoders.build_tokenizer(args)
def build_bpe(self, args):
"""Build the tokenizer for this task."""
return encoders.build_bpe(args)
def get_interactive_tokens_and_lengths(self, lines, encode_fn):
tokens = [
self.source_dictionary.encode_line(
encode_fn(src_str), add_if_not_exist=False
).long()
for src_str in lines
]
lengths = [t.numel() for t in tokens]
return tokens, lengths
class LegacyFairseqTask(FairseqTask):
def __init__(self, args: Namespace):
super().__init__(None)
self.args = args
self.datasets = {}
self.dataset_to_epoch_iter = {}
@classmethod
def setup_task(cls, args: Namespace, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
return cls(args, **kwargs)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "")
def build_model(self, args: Namespace, from_checkpoint=False):
"""
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance
"""
from fairseq import models, quantization_utils
model = models.build_model(args, self, from_checkpoint)
model = quantization_utils.quantize_model_scalar(model, args)
return model
def build_criterion(self, args: Namespace):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance
"""
from fairseq import criterions
return criterions.build_criterion(args, self)

View File

@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.text_to_speech import TextToSpeechTask
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
@register_task("frm_text_to_speech")
class FrmTextToSpeechTask(TextToSpeechTask):
@staticmethod
def add_args(parser):
TextToSpeechTask.add_args(parser)
parser.add_argument("--do_chunk", action="store_true", help="train on chunks")
parser.add_argument("--chunk_bound", default=-1, type=int)
parser.add_argument("--chunk_init", default=50, type=int)
parser.add_argument("--chunk_incr", default=5, type=int)
parser.add_argument("--add_eos", action="store_true")
parser.add_argument("--dedup", action="store_true")
parser.add_argument("--ref_fpu", default=-1, type=float)
def load_dataset(self, split, **unused_kwargs):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv(
self.args.data,
self.data_cfg,
split,
self.src_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split=is_train_split,
n_frames_per_step=self.args.n_frames_per_step,
speaker_to_id=self.speaker_to_id,
do_chunk=self.args.do_chunk,
chunk_bound=self.args.chunk_bound,
chunk_init=self.args.chunk_init,
chunk_incr=self.args.chunk_incr,
add_eos=self.args.add_eos,
dedup=self.args.dedup,
ref_fpu=self.args.ref_fpu,
)

View File

@@ -0,0 +1,191 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import sys
from typing import Dict, List, Optional, Tuple
import numpy as np
from dataclasses import dataclass, field
from fairseq.data import Dictionary, HubertDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING
logger = logging.getLogger(__name__)
class LabelEncoder(object):
def __init__(self, dictionary: Dictionary) -> None:
self.dictionary = dictionary
def __call__(self, label: str) -> List[str]:
return self.dictionary.encode_line(
label,
append_eos=False,
add_if_not_exist=False,
)
@dataclass
class HubertPretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
fine_tuning: bool = field(
default=False, metadata={"help": "set to true if fine-tuning Hubert"}
)
labels: List[str] = field(
default_factory=lambda: ["ltr"],
metadata={
"help": (
"extension of the label files to load, frame-level labels for"
" pre-training, and sequence-level label for fine-tuning"
)
},
)
label_dir: Optional[str] = field(
default=None,
metadata={
"help": "if set, looks for labels in this directory instead",
},
)
label_rate: float = field(
default=-1.0,
metadata={"help": "label frame rate. -1.0 for sequence label"},
)
sample_rate: int = field(
default=16_000,
metadata={
"help": "target sample rate. audio files will be up/down "
"sampled to this rate"
},
)
normalize: bool = field(
default=False,
metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
)
enable_padding: bool = field(
default=False,
metadata={"help": "pad shorter samples instead of cropping"},
)
max_keep_size: Optional[int] = field(
default=None,
metadata={"help": "exclude sample longer than this"},
)
max_sample_size: Optional[int] = field(
default=None,
metadata={"help": "max sample size to crop to for batching"},
)
min_sample_size: Optional[int] = field(
default=None,
metadata={"help": "min sample size to crop to for batching"},
)
single_target: Optional[bool] = field(
default=False,
metadata={
"help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset"
},
)
random_crop: Optional[bool] = field(
default=True,
metadata={"help": "always crop from the beginning if false"},
)
pad_audio: Optional[bool] = field(
default=False,
metadata={"help": "pad audio to the longest one in the batch if true"},
)
@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig)
class HubertPretrainingTask(FairseqTask):
cfg: HubertPretrainingConfig
def __init__(
self,
cfg: HubertPretrainingConfig,
) -> None:
super().__init__(cfg)
logger.info(f"current directory is {os.getcwd()}")
logger.info(f"HubertPretrainingTask Config {cfg}")
self.cfg = cfg
self.fine_tuning = cfg.fine_tuning
if cfg.fine_tuning:
self.state.add_factory("target_dictionary", self.load_dictionaries)
else:
self.state.add_factory("dictionaries", self.load_dictionaries)
self.blank_symbol = "<s>"
@property
def source_dictionary(self) -> Optional[Dictionary]:
return None
@property
def target_dictionary(self) -> Optional[Dictionary]:
return self.state.target_dictionary
@property
def dictionaries(self) -> List[Dictionary]:
return self.state.dictionaries
@classmethod
def setup_task(
cls, cfg: HubertPretrainingConfig, **kwargs
) -> "HubertPretrainingTask":
return cls(cfg)
def load_dictionaries(self):
label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
dictionaries = [
Dictionary.load(f"{label_dir}/dict.{label}.txt")
for label in self.cfg.labels
]
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
def get_label_dir(self) -> str:
if self.cfg.label_dir is None:
return self.cfg.data
return self.cfg.label_dir
def load_dataset(self, split: str, **kwargs) -> None:
manifest = f"{self.cfg.data}/{split}.tsv"
dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
pad_list = [dict.pad() for dict in dicts]
eos_list = [dict.eos() for dict in dicts]
procs = [LabelEncoder(dict) for dict in dicts]
paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]
# hubert v1: pad_audio=True, random_crop=False;
self.datasets[split] = HubertDataset(
manifest,
sample_rate=self.cfg.sample_rate,
label_paths=paths,
label_rates=self.cfg.label_rate,
pad_list=pad_list,
eos_list=eos_list,
label_processors=procs,
max_keep_sample_size=self.cfg.max_keep_size,
min_keep_sample_size=self.cfg.min_sample_size,
max_sample_size=self.cfg.max_sample_size,
pad_audio=self.cfg.pad_audio,
normalize=self.cfg.normalize,
store_labels=False,
random_crop=self.cfg.random_crop,
single_target=self.cfg.single_target,
)
def max_positions(self) -> Tuple[int, int]:
return (sys.maxsize, sys.maxsize)
def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
return indices

View File

@@ -0,0 +1,383 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from fairseq import utils
from fairseq.data import (
AppendTokenDataset,
Dictionary,
IdDataset,
LMContextWindowDataset,
MonolingualDataset,
NestedDictionaryDataset,
NumelDataset,
PadDataset,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
TruncatedDictionary,
data_utils,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import LegacyFairseqTask, register_task
from omegaconf import II
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
logger = logging.getLogger(__name__)
@dataclass
class LanguageModelingConfig(FairseqDataclass):
data: Optional[str] = field(
default=None, metadata={"help": "path to data directory"}
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="none",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
output_dictionary_size: int = field(
default=-1, metadata={"help": "limit the size of output dictionary"}
)
self_target: bool = field(default=False, metadata={"help": "include self target"})
future_target: bool = field(
default=False, metadata={"help": "include future target"}
)
past_target: bool = field(default=False, metadata={"help": "include past target"})
add_bos_token: bool = field(
default=False, metadata={"help": "prepend beginning of sentence token (<s>)"}
)
max_target_positions: Optional[int] = field(
default=None, metadata={"help": "max number of tokens in the target sequence"}
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
pad_to_fixed_length: Optional[bool] = field(
default=False,
metadata={"help": "pad to fixed length"},
)
pad_to_fixed_bsz: Optional[bool] = field(
default=False,
metadata={"help": "boolean to pad to fixed batch size"},
)
# TODO common vars below add to parent
seed: int = II("common.seed")
batch_size: Optional[int] = II("dataset.batch_size")
batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
data_buffer_size: int = II("dataset.data_buffer_size")
tpu: bool = II("common.tpu")
use_plasma_view: bool = II("common.use_plasma_view")
plasma_path: str = II("common.plasma_path")
@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(LegacyFairseqTask):
"""
Train a language model.
Args:
dictionary (~fairseq.data.Dictionary): the dictionary for the input of
the language model
output_dictionary (~fairseq.data.Dictionary): the dictionary for the
output of the language model. In most cases it will be the same as
*dictionary*, but could possibly be a more limited version of the
dictionary (if ``--output-dictionary-size`` is used).
targets (List[str]): list of the target types that the language model
should predict. Can be one of "self", "future", and "past".
Defaults to "future".
.. note::
The language modeling task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate`, :mod:`fairseq-interactive` and
:mod:`fairseq-eval-lm`.
The language modeling task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.language_modeling_parser
:prog:
"""
def __init__(self, args, dictionary, output_dictionary=None, targets=None):
super().__init__(args)
self.dictionary = dictionary
self.output_dictionary = output_dictionary or dictionary
if targets is None:
targets = ["future"]
self.targets = targets
@classmethod
def setup_dictionary(cls, args, **kwargs):
dictionary = None
output_dictionary = None
if args.data:
paths = utils.split_paths(args.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(
dictionary, args.output_dictionary_size
)
return (dictionary, output_dictionary)
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)
# upgrade old checkpoints
if getattr(args, "exclude_self_target", False):
args.self_target = False
targets = []
if getattr(args, "self_target", False):
targets.append("self")
if getattr(args, "future_target", False):
targets.append("future")
if getattr(args, "past_target", False):
targets.append("past")
if len(targets) == 0:
# standard language modeling
targets = ["future"]
return cls(args, dictionary, output_dictionary, targets=targets)
def build_model(self, args, from_checkpoint=False):
model = super().build_model(args, from_checkpoint)
for target in self.targets:
if target not in model.supported_targets:
raise ValueError(
"Unsupported language modeling target: {}".format(target)
)
return model
def load_dataset(
self, split: str, epoch=1, combine=False, **kwargs
) -> MonolingualDataset:
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, valid1, test)
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
# each process has its own copy of the raw data (likely to be an np.memmap)
dataset = data_utils.load_indexed_dataset(
split_path, self.dictionary, self.args.dataset_impl, combine=combine
)
if dataset is None:
raise FileNotFoundError(f"Dataset not found: {split} ({split_path})")
dataset = maybe_shorten_dataset(
dataset,
split,
self.args.shorten_data_split_list,
self.args.shorten_method,
self.args.tokens_per_sample,
self.args.seed,
)
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample,
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode,
include_targets=True,
use_plasma_view=self.args.use_plasma_view,
split_path=split_path,
plasma_path=self.args.plasma_path,
)
add_eos_for_other_targets = (
self.args.sample_break_mode is not None
and self.args.sample_break_mode != "none"
)
fixed_pad_length = None
if self.args.pad_to_fixed_length:
fixed_pad_length = self.args.tokens_per_sample
pad_to_bsz = None
if self.args.pad_to_fixed_bsz:
pad_to_bsz = (
self.args.batch_size_valid if "valid" in split else self.args.batch_size
)
self.datasets[split] = MonolingualDataset(
dataset=dataset,
sizes=dataset.sizes,
src_vocab=self.dictionary,
tgt_vocab=self.output_dictionary,
add_eos_for_other_targets=add_eos_for_other_targets,
shuffle=True,
targets=self.targets,
add_bos_token=self.args.add_bos_token,
fixed_pad_length=fixed_pad_length,
pad_to_bsz=pad_to_bsz,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
"""
Generate batches for inference. We prepend an eos token to src_tokens
(or bos if `--add-bos-token` is set) and we append a <pad> to target.
This is convenient both for generation with a prefix and LM scoring.
"""
dataset = StripTokenDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
),
# remove eos from (end of) target sequence
self.source_dictionary.eos(),
)
src_dataset = PrependTokenDataset(
dataset,
token=(
self.source_dictionary.bos()
if getattr(self.args, "add_bos_token", False)
else self.source_dictionary.eos()
),
)
tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
return NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": PadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
"target": PadDataset(
tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False
),
},
sizes=[np.array(src_lengths)],
)
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
# Generation will always be conditioned on bos_token
if getattr(self.args, "add_bos_token", False):
bos_token = self.source_dictionary.bos()
else:
bos_token = self.source_dictionary.eos()
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the language_modeling task is not supported"
)
# SequenceGenerator doesn't use src_tokens directly, we need to
# pass the `prefix_tokens` argument instead
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
prefix_tokens = sample["net_input"]["src_tokens"]
if prefix_tokens[:, 0].eq(bos_token).all():
prefix_tokens = prefix_tokens[:, 1:]
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
)
def eval_lm_dataloader(
self,
dataset,
max_tokens: Optional[int] = 36000,
batch_size: Optional[int] = None,
max_positions: Optional[int] = None,
num_shards: int = 1,
shard_id: int = 0,
num_workers: int = 1,
data_buffer_size: int = 10,
# ensures that every evaluated token has access to a context of at least
# this size, if possible
context_window: int = 0,
):
if context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=self.args.tokens_per_sample,
context_window=context_window,
pad_idx=self.source_dictionary.pad(),
)
return self.get_batch_iterator(
dataset=dataset,
max_tokens=max_tokens,
max_sentences=batch_size,
max_positions=max_positions,
ignore_invalid_inputs=True,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
data_buffer_size=data_buffer_size,
).next_epoch_itr(shuffle=False)
@property
def source_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.output_dictionary

View File

@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("legacy_masked_lm")
class LegacyMaskedLMTask(LegacyFairseqTask):
"""
Task for training Masked LM (BERT) model.
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data",
help="colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner",
)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments"
" per sample for BERT dataset",
)
parser.add_argument(
"--break-mode", default="doc", type=str, help="mode for breaking sentence"
)
parser.add_argument("--shuffle-dataset", action="store_true", default=False)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
@classmethod
def load_dictionary(cls, filename):
return BertDictionary.load(filename)
@classmethod
def build_dictionary(
cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
):
d = BertDictionary()
for filename in filenames:
Dictionary.add_file_to_dictionary(
filename, d, tokenizer.tokenize_line, workers
)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@property
def target_dictionary(self):
return self.dictionary
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
paths = utils.split_paths(args.data)
assert len(paths) > 0
dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
loaded_datasets = []
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
logger.info("data_path", data_path)
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
path = os.path.join(data_path, split_k)
ds = indexed_dataset.make_dataset(
path,
impl=self.args.dataset_impl,
fix_lua_indexing=True,
dictionary=self.dictionary,
)
if ds is None:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
with data_utils.numpy_seed(self.seed + k):
loaded_datasets.append(
BlockPairDataset(
ds,
self.dictionary,
ds.sizes,
self.args.tokens_per_sample,
break_mode=self.args.break_mode,
doc_break_size=1,
)
)
logger.info(
"{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
)
if not combine:
break
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
self.datasets[split] = MaskedLMDataset(
dataset=dataset,
sizes=sizes,
vocab=self.dictionary,
pad_idx=self.dictionary.pad(),
mask_idx=self.dictionary.mask(),
classif_token_idx=self.dictionary.cls(),
sep_token_idx=self.dictionary.sep(),
shuffle=self.args.shuffle_dataset,
seed=self.seed,
)

View File

@@ -0,0 +1,270 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
import numpy as np
from omegaconf import II, MISSING, OmegaConf
from fairseq import utils
from fairseq.data import (
Dictionary,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PrependTokenDataset,
RightPadDataset,
SortDataset,
TokenBlockDataset,
data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES
logger = logging.getLogger(__name__)
@dataclass
class MaskedLMConfig(FairseqDataclass):
data: str = field(
default=MISSING,
metadata={
"help": "colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="none",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
mask_prob: float = field(
default=0.15,
metadata={"help": "probability of replacing a token with mask"},
)
leave_unmasked_prob: float = field(
default=0.1,
metadata={"help": "probability that a masked token is unmasked"},
)
random_token_prob: float = field(
default=0.1,
metadata={"help": "probability of replacing a token with a random token"},
)
freq_weighted_replacement: bool = field(
default=False,
metadata={"help": "sample random replacement words based on word frequencies"},
)
mask_whole_words: bool = field(
default=False,
metadata={"help": "mask whole words; you may also want to set --bpe"},
)
mask_multiple_length: int = field(
default=1,
metadata={"help": "repeat the mask indices multiple times"},
)
mask_stdev: float = field(
default=0.0,
metadata={"help": "stdev of the mask length"},
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
seed: int = II("common.seed")
include_target_tokens: bool = field(
default=False,
metadata={
"help": "include target tokens in model input. this is used for data2vec"
},
)
@register_task("masked_lm", dataclass=MaskedLMConfig)
class MaskedLMTask(FairseqTask):
cfg: MaskedLMConfig
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
def __init__(self, cfg: MaskedLMConfig, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
# add mask token
self.mask_idx = dictionary.add_symbol("<mask>")
@classmethod
def setup_task(cls, cfg: MaskedLMConfig, **kwargs):
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(cfg, dictionary)
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
dataset = maybe_shorten_dataset(
dataset,
split,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.cfg.tokens_per_sample,
self.cfg.seed,
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.cfg.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode=self.cfg.sample_break_mode,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
return PrependTokenDataset(dataset, self.source_dictionary.bos())
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset = self._load_dataset_split(split, epoch, combine)
# create masked input and targets
mask_whole_words = (
get_whole_word_mask(self.args, self.source_dictionary)
if self.cfg.mask_whole_words
else None
)
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,
self.source_dictionary,
pad_idx=self.source_dictionary.pad(),
mask_idx=self.mask_idx,
seed=self.cfg.seed,
mask_prob=self.cfg.mask_prob,
leave_unmasked_prob=self.cfg.leave_unmasked_prob,
random_token_prob=self.cfg.random_token_prob,
freq_weighted_replacement=self.cfg.freq_weighted_replacement,
mask_whole_words=mask_whole_words,
mask_multiple_length=self.cfg.mask_multiple_length,
mask_stdev=self.cfg.mask_stdev,
)
with data_utils.numpy_seed(self.cfg.seed):
shuffle = np.random.permutation(len(src_dataset))
target_dataset = RightPadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
)
input_dict = {
"src_tokens": RightPadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": NumelDataset(src_dataset, reduce=False),
}
if self.cfg.include_target_tokens:
input_dict["target_tokens"] = target_dataset
self.datasets[split] = SortDataset(
NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": input_dict,
"target": target_dataset,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_dataset, reduce=True),
},
sizes=[src_dataset.sizes],
),
sort_order=[
shuffle,
src_dataset.sizes,
],
)
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = RightPadDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
self.cfg.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
),
pad_idx=self.source_dictionary.pad(),
)
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": src_dataset,
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
},
sizes=src_lengths,
)
if sort:
src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
return src_dataset
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -0,0 +1,268 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from omegaconf import II
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
DenoisingDataset,
Dictionary,
PrependTokenDataset,
ResamplingDataset,
SortDataset,
TokenBlockDataset,
data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import register_task
from .denoising import DenoisingConfig, DenoisingTask
logger = logging.getLogger(__name__)
@dataclass
class MultilingualDenoisingConfig(DenoisingConfig):
multilang_sampling_alpha: float = field(
default=1.0,
metadata={"help": "smoothing alpha for sample ratios across multiple datasets"},
)
add_lang_token: bool = field(
default=False,
metadata={"help": ""},
)
langs: Optional[str] = field(
default=None,
metadata={"help": "language ids we are considering"},
)
no_whole_word_mask_langs: str = field(
default="",
metadata={
"help": "languages without spacing between words don't support whole word masking"
},
)
train_subset: str = II("common.train_subset")
valid_subset: str = II("common.valid_subset")
@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig)
class MultilingualDenoisingTask(DenoisingTask):
cfg: MultilingualDenoisingConfig
@classmethod
def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs):
"""Setup the task."""
paths = cfg.data.split(":")
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
data_path = paths[0]
if cfg.langs is None:
languages = sorted(
[
name
for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
]
)
else:
languages = cfg.langs.split(",")
if cfg.add_lang_token:
for lang in languages:
dictionary.add_symbol("[{}]".format(lang))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(cfg, "shuffle_instance"):
cfg.shuffle_instance = False
return cls(cfg, dictionary)
def __init__(self, cfg: MultilingualDenoisingConfig, dictionary):
super().__init__(cfg, dictionary)
self.dictionary = dictionary
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
self.cfg = cfg
def _get_sample_prob(self, dataset_lens):
"""
Get smoothed sampling probability by languages. This helps low resource
languages by upsampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob**self.cfg.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.cfg.data.split(":")
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
if self.cfg.langs is None:
languages = sorted(
[
name
for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
]
)
else:
languages = self.cfg.langs.split(",")
for name in languages:
p = os.path.join(data_path, name)
assert os.path.exists(p), "data not found: {}".format(p)
logger.info("Training on {0} languages: {1}".format(len(languages), languages))
logger.info(
"Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
)
mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary)
language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",")
lang_datasets = []
for language in languages:
split_path = os.path.join(data_path, language, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
end_token = (
self.source_dictionary.index("[{}]".format(language))
if self.cfg.add_lang_token
else self.source_dictionary.eos()
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.cfg.tokens_per_sample - 2, # one less for <s>
pad=self.source_dictionary.pad(),
eos=end_token,
break_mode=self.cfg.sample_break_mode,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, end_token)
lang_mask_whole_words = (
mask_whole_words
if language not in language_without_segmentations
else None
)
lang_dataset = DenoisingDataset(
dataset,
dataset.sizes,
self.dictionary,
self.mask_idx,
lang_mask_whole_words,
shuffle=self.cfg.shuffle_instance,
seed=self.cfg.seed,
mask=self.cfg.mask,
mask_random=self.cfg.mask_random,
insert=self.cfg.insert,
rotate=self.cfg.rotate,
permute_sentences=self.cfg.permute_sentences,
bpe=self.cfg.bpe,
replace_length=self.cfg.replace_length,
mask_length=self.cfg.mask_length,
poisson_lambda=self.cfg.poisson_lambda,
eos=None
if not self.cfg.add_lang_token
else self.source_dictionary.index("[{}]".format(language)),
)
lang_datasets.append(lang_dataset)
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
logger.info(
"loaded total {} blocks for all languages".format(
int(dataset_lengths.sum()),
)
)
if split == self.cfg.train_subset:
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths)
logger.info(
"Sample probability by language: {}".format(
{
lang: "{0:.4f}".format(sample_probs[id])
for id, lang in enumerate(languages)
}
)
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
logger.info(
"Up/Down Sampling ratio by language: {}".format(
{
lang: "{0:.2f}".format(size_ratio[id])
for id, lang in enumerate(languages)
}
)
)
resampled_lang_datasets = [
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.cfg.seed,
epoch=epoch,
replace=size_ratio[i] >= 1.0,
)
for i, d in enumerate(lang_datasets)
]
dataset = ConcatDataset(
resampled_lang_datasets,
)
else:
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
split_name = split + "_" + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
if split in self.cfg.valid_subset:
self.cfg.valid_subset = self.cfg.valid_subset.replace(
split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.cfg.seed + epoch):
shuffle = np.random.permutation(len(dataset))
self.datasets[split] = SortDataset(
dataset,
sort_order=[
shuffle,
dataset.sizes,
],
)

View File

@@ -0,0 +1,627 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from omegaconf import II
from fairseq import utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
Dictionary,
IdDataset,
LMContextWindowDataset,
MonolingualDataset,
NestedDictionaryDataset,
NumelDataset,
PadDataset,
PrependTokenDataset,
ResamplingDataset,
SortDataset,
StripTokenDataset,
TokenBlockDataset,
TruncatedDictionary,
data_utils,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import LegacyFairseqTask, register_task
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
logger = logging.getLogger(__name__)
def lang_token(lang):
return f"<{lang}>"
@dataclass
class MultilingualLanguageModelingConfig(FairseqDataclass):
# TODO common var add to parent
data: Optional[str] = field(
default=None, metadata={"help": "path to data directory"}
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="none",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
output_dictionary_size: int = field(
default=-1, metadata={"help": "limit the size of output dictionary"}
)
self_target: bool = field(default=False, metadata={"help": "include self target"})
future_target: bool = field(
default=False, metadata={"help": "include future target"}
)
past_target: bool = field(default=False, metadata={"help": "include past target"})
add_bos_token: bool = field(
default=False, metadata={"help": "prepend lang id token <dialect>"}
)
max_source_positions: Optional[int] = field(
default=None, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: Optional[int] = field(
default=None, metadata={"help": "max number of tokens in the target sequence"}
)
pad_to_fixed_length: Optional[bool] = field(
default=False, metadata={"help": "pad to fixed length"}
)
pad_to_fixed_bsz: Optional[bool] = field(
default=False, metadata={"help": "boolean to pad to fixed batch size"}
)
multilang_sampling_alpha: Optional[float] = field(
default=1.0,
metadata={
"help": "smoothing alpha for sample rations across multiple datasets"
},
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
langs: str = field(
default="",
metadata={
"help": "comma-separated list of languages (default: all directories in data path)"
},
)
baseline_model_langs: str = field(
default="",
metadata={
"help": "comma-separated list of languages in the baseline model (default: none)"
},
)
# TODO: legacy parameter kept for compatibility
baseline_model: str = field(
default="",
metadata={"help": "path to the baseline model (default: none)"},
)
lang_to_offline_shard_ratio: str = field(
default="",
metadata={
"help": "absolute path of tsv file location to indicate lang to offline shard ratio.",
},
)
# TODO common vars below add to parent
seed: int = II("common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
data_buffer_size: int = II("dataset.data_buffer_size")
tpu: bool = II("common.tpu")
batch_size: Optional[int] = II("dataset.batch_size")
batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
train_subset: str = II("common.train_subset")
valid_subset: str = II("common.valid_subset")
@register_task(
"multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig
)
class MultilingualLanguageModelingTask(LegacyFairseqTask):
"""
Train a language model.
Args:
dictionary (~fairseq.data.Dictionary): the dictionary for the input of
the language model
output_dictionary (~fairseq.data.Dictionary): the dictionary for the
output of the language model. In most cases it will be the same as
*dictionary*, but could possibly be a more limited version of the
dictionary (if ``--output-dictionary-size`` is used).
targets (List[str]): list of the target types that the language model
should predict. Can be one of "self", "future", and "past".
Defaults to "future".
.. note::
The language modeling task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate`, :mod:`fairseq-interactive` and
:mod:`fairseq-eval-lm`.
The language modeling task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.language_modeling_parser
:prog:
"""
def __init__(self, args, dictionary, output_dictionary=None, targets=None):
super().__init__(args)
self.dictionary = dictionary
self.output_dictionary = output_dictionary or dictionary
if targets is None:
targets = ["future"]
self.targets = targets
@staticmethod
def _get_langs(args, epoch=1):
paths = utils.split_paths(args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
languages = sorted(
name
for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
)
if args.langs:
keep_langs = set(args.langs.split(","))
languages = [lang for lang in languages if lang in keep_langs]
assert len(languages) == len(keep_langs)
return languages, data_path
@classmethod
def setup_dictionary(cls, args, **kwargs):
dictionary = None
output_dictionary = None
if args.data:
paths = utils.split_paths(args.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
if args.add_bos_token:
languages, _ = cls._get_langs(args)
logger.info("----------------")
for lang in languages:
dictionary.add_symbol(lang_token(lang))
logger.info(f"add language token: {lang_token(lang)}")
logger.info("----------------")
logger.info("dictionary: {} types".format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
output_dictionary = TruncatedDictionary(
dictionary, args.output_dictionary_size
)
return (dictionary, output_dictionary)
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)
# upgrade old checkpoints
if hasattr(args, "exclude_self_target"):
args.self_target = not args.exclude_self_target
targets = []
if getattr(args, "self_target", False):
targets.append("self")
if getattr(args, "future_target", False):
targets.append("future")
if getattr(args, "past_target", False):
targets.append("past")
if len(targets) == 0:
# standard language modeling
targets = ["future"]
return cls(args, dictionary, output_dictionary, targets=targets)
def build_model(self, args, from_checkpoint=False):
model = super().build_model(args, from_checkpoint)
for target in self.targets:
if target not in model.supported_targets:
raise ValueError(
f"Unsupported language modeling target: {target} not in {model.supported_targets}"
)
return model
def _get_sample_prob(self, dataset_lens):
"""
Get smoothed sampling porbability by languages. This helps low resource
languages by upsampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob**self.args.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
languages, data_path = MultilingualLanguageModelingTask._get_langs(
self.args, epoch
)
lang_to_offline_shard_ratio = None
if self.args.lang_to_offline_shard_ratio != "":
lang_to_offline_shard_ratio = {}
assert os.path.exists(
self.args.lang_to_offline_shard_ratio
), "provided offline shard ratio file doesn't exist: {0}".format(
self.args.lang_to_offline_shard_ratio
)
with open(self.args.lang_to_offline_shard_ratio) as fin:
for line in fin:
lang, ratio = line.strip().split("\t")
ratio = float(ratio)
lang_to_offline_shard_ratio[lang] = ratio
logger.info(
"Found offline sharded ratio: %s",
lang_to_offline_shard_ratio,
)
if split == self.args.train_subset:
logger.info(
"Training on {0} languages: {1}".format(len(languages), languages)
)
else:
logger.info(
"Evaluating on {0} languages: {1}".format(len(languages), languages)
)
tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)
fixed_pad_length = None
if self.args.pad_to_fixed_length:
fixed_pad_length = self.args.tokens_per_sample
pad_to_bsz = None
if self.args.pad_to_fixed_bsz:
pad_to_bsz = (
self.args.batch_size_valid if "valid" in split else self.args.batch_size
)
lang_datasets = []
for lang_id, language in enumerate(languages):
split_path = os.path.join(data_path, language, split)
dataset = data_utils.load_indexed_dataset(
split_path, self.dictionary, self.args.dataset_impl, combine=combine
)
# print('len(dataset) =', len(dataset))
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
dataset = maybe_shorten_dataset(
dataset,
split,
self.args.shorten_data_split_list,
self.args.shorten_method,
tokens_per_sample,
self.args.seed,
)
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
tokens_per_sample,
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode,
include_targets=True,
)
add_eos_for_other_targets = (
self.args.sample_break_mode is not None
and self.args.sample_break_mode != "none"
)
src_lang_idx, tgt_lang_idx = None, None
if self.args.add_bos_token:
src_lang_idx = self.dictionary.index(lang_token(language))
tgt_lang_idx = self.output_dictionary.index(lang_token(language))
lang_datasets.append(
MonolingualDataset(
dataset=dataset,
sizes=dataset.sizes,
src_vocab=self.dictionary,
tgt_vocab=self.output_dictionary,
add_eos_for_other_targets=add_eos_for_other_targets,
shuffle=True,
targets=self.targets,
fixed_pad_length=fixed_pad_length,
pad_to_bsz=pad_to_bsz,
add_bos_token=self.args.add_bos_token,
src_lang_idx=src_lang_idx,
tgt_lang_idx=tgt_lang_idx,
)
)
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
logger.info(
"loaded total {} blocks for all languages".format(
dataset_lengths.sum(),
)
)
if split == self.args.train_subset:
dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
if lang_to_offline_shard_ratio is not None:
dataset_lengths_ratio_multiplier = []
for lang in languages:
assert (
lang in lang_to_offline_shard_ratio
), "Lang: {0} missing in offline shard ratio file: {1}".format(
lang,
self.args.lang_to_offline_shard_ratio,
)
dataset_lengths_ratio_multiplier.append(
lang_to_offline_shard_ratio[lang]
)
dataset_lengths_ratio_multiplier = np.array(
dataset_lengths_ratio_multiplier
)
true_dataset_lengths = (
dataset_lengths * dataset_lengths_ratio_multiplier
)
else:
true_dataset_lengths = dataset_lengths
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(true_dataset_lengths)
logger.info(
"Sample probability by language: %s",
{
lang: "{0:.4f}".format(sample_probs[id])
for id, lang in enumerate(languages)
},
)
size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
# TODO: add an option for shrinking all size ratios to below 1
# if self.args.multilang_sampling_alpha != 1:
# size_ratio /= size_ratio.max()
# Fix numeric errors in size ratio computation
# 0.999999999999999999 -> 1
# 1.000000000000000002 -> 1
for i in range(len(size_ratio)):
size_ratio[i] = round(size_ratio[i], 8)
logger.info(
"Up/Down Sampling ratio by language: %s",
{
lang: "{0:.2f}".format(size_ratio[id])
for id, lang in enumerate(languages)
},
)
logger.info(
"Actual dataset size by language: %s",
{
lang: "{0:.2f}".format(len(lang_datasets[id]))
for id, lang in enumerate(languages)
},
)
resampled_lang_datasets = [
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.args.seed,
epoch=epoch,
replace=size_ratio[i] > 1.0,
)
for i, d in enumerate(lang_datasets)
]
logger.info(
"Resampled dataset size by language: %s",
{
lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
for id, lang in enumerate(languages)
},
)
dataset = ConcatDataset(resampled_lang_datasets)
else:
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
split_name = split + "_" + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
# [TODO]: This is hacky for now to print validation ppl for each
# language individually. Maybe need task API changes to allow it
# in more generic ways.
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
shuffle = np.random.permutation(len(dataset))
self.datasets[split] = SortDataset(
dataset,
sort_order=[
shuffle,
dataset.sizes,
],
)
def build_dataset_for_inference(
self, src_tokens, src_lengths, language="en_XX", **kwargs
):
"""
Generate batches for inference. We prepend an eos token to src_tokens
(or bos if `--add-bos-token` is set) and we append a <pad> to target.
This is convenient both for generation with a prefix and LM scoring.
"""
dataset = StripTokenDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
),
# remove eos from (end of) target sequence
self.source_dictionary.eos(),
)
src_lang_idx = self.dictionary.index(lang_token(language))
src_dataset = PrependTokenDataset(
dataset,
token=(
(src_lang_idx or self.source_dictionary.bos())
if getattr(self.args, "add_bos_token", False)
else self.source_dictionary.eos()
),
)
max_seq_len = max(src_lengths) + 1
tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
return NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": PadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
pad_length=max_seq_len,
),
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
"target": PadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
pad_length=max_seq_len,
),
},
sizes=[np.array(src_lengths)],
)
@torch.no_grad()
def inference_step(
self,
generator,
models,
sample,
language="en_XX",
prefix_tokens=None,
constraints=None,
):
# Generation will always be conditioned on bos_token
if getattr(self.args, "add_bos_token", False):
src_lang_idx = self.dictionary.index(lang_token(language))
bos_token = src_lang_idx or self.source_dictionary.bos()
else:
bos_token = self.source_dictionary.eos()
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the language_modeling task is not supported"
)
# SequenceGenerator doesn't use src_tokens directly, we need to
# pass the `prefix_tokens` argument instead
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
prefix_tokens = sample["net_input"]["src_tokens"]
if prefix_tokens[:, 0].eq(bos_token).all():
prefix_tokens = prefix_tokens[:, 1:]
return generator.generate(
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
)
def eval_lm_dataloader(
self,
dataset,
max_tokens: Optional[int] = 36000,
batch_size: Optional[int] = None,
max_positions: Optional[int] = None,
num_shards: int = 1,
shard_id: int = 0,
num_workers: int = 1,
data_buffer_size: int = 10,
# ensures that every evaluated token has access to a context of at least
# this size, if possible
context_window: int = 0,
):
if context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=self.args.tokens_per_sample,
context_window=context_window,
pad_idx=self.source_dictionary.pad(),
)
return self.get_batch_iterator(
dataset=dataset,
max_tokens=max_tokens,
max_sentences=batch_size,
max_positions=max_positions,
ignore_invalid_inputs=True,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
data_buffer_size=data_buffer_size,
)
@property
def source_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.output_dictionary

View File

@@ -0,0 +1,338 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import numpy as np
import torch
from fairseq import utils
from fairseq.data import (
ConcatDataset,
Dictionary,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
PrependTokenDataset,
RawLabelDataset,
ResamplingDataset,
SortDataset,
TokenBlockDataset,
data_utils,
encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("multilingual_masked_lm")
class MultiLingualMaskedLMTask(LegacyFairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument(
"data",
help="colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner",
)
parser.add_argument(
"--sample-break-mode",
default="complete",
choices=["none", "complete", "complete_doc", "eos"],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.',
)
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments "
"per sample for BERT dataset",
)
parser.add_argument(
"--mask-prob",
default=0.15,
type=float,
help="probability of replacing a token with mask",
)
parser.add_argument(
"--leave-unmasked-prob",
default=0.1,
type=float,
help="probability that a masked token is unmasked",
)
parser.add_argument(
"--random-token-prob",
default=0.1,
type=float,
help="probability of replacing a token with a random token",
)
parser.add_argument(
"--freq-weighted-replacement",
action="store_true",
help="sample random replacement words based on word frequencies",
)
parser.add_argument(
"--mask-whole-words",
default=False,
action="store_true",
help="mask whole words; you may also want to set --bpe",
)
parser.add_argument(
"--multilang-sampling-alpha",
type=float,
default=1.0,
help="smoothing alpha for sample rations across multiple datasets",
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = dictionary.add_symbol("<mask>")
@classmethod
def setup_task(cls, args, **kwargs):
paths = utils.split_paths(args.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
return cls(args, dictionary)
def _get_whole_word_mask(self):
# create masked input and targets
if self.args.mask_whole_words:
bpe = encoders.build_bpe(self.args)
if bpe is not None:
def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
if tok.startswith("madeupword"):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(
list(map(is_beginning_of_word, range(len(self.source_dictionary))))
)
else:
mask_whole_words = None
return mask_whole_words
def _get_sample_prob(self, dataset_lens):
"""
Get smoothed sampling porbability by languages. This helps low resource
languages by upsampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob**self.args.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
languages = sorted(
name
for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
)
logger.info("Training on {0} languages: {1}".format(len(languages), languages))
logger.info(
"Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
)
mask_whole_words = self._get_whole_word_mask()
lang_datasets = []
for lang_id, language in enumerate(languages):
split_path = os.path.join(data_path, language, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode=self.args.sample_break_mode,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,
self.source_dictionary,
pad_idx=self.source_dictionary.pad(),
mask_idx=self.mask_idx,
seed=self.args.seed,
mask_prob=self.args.mask_prob,
leave_unmasked_prob=self.args.leave_unmasked_prob,
random_token_prob=self.args.random_token_prob,
freq_weighted_replacement=self.args.freq_weighted_replacement,
mask_whole_words=mask_whole_words,
)
lang_dataset = NestedDictionaryDataset(
{
"net_input": {
"src_tokens": PadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
"target": PadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_dataset, reduce=True),
"lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
},
sizes=[src_dataset.sizes],
)
lang_datasets.append(lang_dataset)
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
logger.info(
"loaded total {} blocks for all languages".format(
dataset_lengths.sum(),
)
)
if split == self.args.train_subset:
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths)
logger.info(
"Sample probability by language: ",
{
lang: "{0:.4f}".format(sample_probs[id])
for id, lang in enumerate(languages)
},
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
logger.info(
"Up/Down Sampling ratio by language: ",
{
lang: "{0:.2f}".format(size_ratio[id])
for id, lang in enumerate(languages)
},
)
resampled_lang_datasets = [
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.args.seed,
epoch=epoch,
replace=size_ratio[i] >= 1.0,
)
for i, d in enumerate(lang_datasets)
]
dataset = ConcatDataset(resampled_lang_datasets)
else:
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
split_name = split + "_" + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
# [TODO]: This is hacky for now to print validation ppl for each
# language individually. Maybe need task API changes to allow it
# in more generic ways.
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
shuffle = np.random.permutation(len(dataset))
self.datasets[split] = SortDataset(
dataset,
sort_order=[
shuffle,
dataset.sizes,
],
)
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = PadDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
self.args.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
),
pad_idx=self.source_dictionary.pad(),
left_pad=False,
)
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": src_dataset,
"src_lengths": NumelDataset(src_dataset, reduce=False),
},
},
sizes=src_lengths,
)
if sort:
src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
return src_dataset
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -0,0 +1,462 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
from collections import OrderedDict
from argparse import ArgumentError
import torch
from fairseq import metrics, options, utils
from fairseq.data import (
Dictionary,
LanguagePairDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset
from . import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
def _lang_token(lang: str):
return "__{}__".format(lang)
def _lang_token_index(dic: Dictionary, lang: str):
"""Return language token index."""
idx = dic.index(_lang_token(lang))
assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
return idx
@register_task("multilingual_translation")
class MultilingualTranslationTask(LegacyFairseqTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered
according to the `--lang-pairs` argument.
The training loop is roughly:
for i in range(len(epoch)):
for lang_pair in args.lang_pairs:
batch = next_batch_for_lang_pair(lang_pair)
loss = criterion(model_for_lang_pair(lang_pair), batch)
loss.backward()
optimizer.step()
In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
(e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
implements the `FairseqMultiModel` interface.
During inference it is required to specify a single `--source-lang` and
`--target-lang`, which indicates the inference langauge direction.
`--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to
the same value as training.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language (only needed for inference)')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language (only needed for inference)')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left (default: False)')
try:
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
except ArgumentError:
# this might have already been defined. Once we transition this to hydra it should be fine to add it here.
pass
parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset')
parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
metavar='SRCTGT',
help='replace beginning-of-sentence in source sentence with source or target '
'language token. (src/tgt)')
parser.add_argument('--decoder-langtok', action='store_true',
help='replace beginning-of-sentence in target sentence with target language token')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args)
self.dicts = dicts
self.training = training
if training:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
# eval_lang_pairs for multilingual translation is usually all of the
# lang_pairs. However for other multitask settings or when we want to
# optimize for certain languages we want to use a different subset. Thus
# the eval_lang_pairs class variable is provided for classes that extend
# this class.
self.eval_lang_pairs = self.lang_pairs
# model_lang_pairs will be used to build encoder-decoder model pairs in
# models.build_model(). This allows multitask type of sub-class can
# build models other than the input lang_pairs
self.model_lang_pairs = self.lang_pairs
self.langs = list(dicts.keys())
@classmethod
def setup_task(cls, args, **kwargs):
dicts, training = cls.prepare(args, **kwargs)
return cls(args, dicts, training)
@classmethod
def update_args(cls, args):
args.left_pad_source = utils.eval_bool(args.left_pad_source)
args.left_pad_target = utils.eval_bool(args.left_pad_target)
if args.lang_pairs is None:
raise ValueError(
"--lang-pairs is required. List all the language pairs in the training objective."
)
if isinstance(args.lang_pairs, str):
args.lang_pairs = args.lang_pairs.split(",")
@classmethod
def prepare(cls, args, **kargs):
cls.update_args(args)
sorted_langs = sorted(
list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
)
if args.source_lang is not None or args.target_lang is not None:
training = False
else:
training = True
# load dictionaries
dicts = OrderedDict()
for lang in sorted_langs:
paths = utils.split_paths(args.data)
assert len(paths) > 0
dicts[lang] = cls.load_dictionary(
os.path.join(paths[0], "dict.{}.txt".format(lang))
)
if len(dicts) > 0:
assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
if args.encoder_langtok is not None or args.decoder_langtok:
for lang_to_add in sorted_langs:
dicts[lang].add_symbol(_lang_token(lang_to_add))
logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
return dicts, training
def get_encoder_langtok(self, src_lang, tgt_lang):
if self.args.encoder_langtok is None:
return self.dicts[src_lang].eos()
if self.args.encoder_langtok == "src":
return _lang_token_index(self.dicts[src_lang], src_lang)
else:
return _lang_token_index(self.dicts[src_lang], tgt_lang)
def get_decoder_langtok(self, tgt_lang):
if not self.args.decoder_langtok:
return self.dicts[tgt_lang].eos()
return _lang_token_index(self.dicts[tgt_lang], tgt_lang)
def alter_dataset_langtok(
self,
lang_pair_dataset,
src_eos=None,
src_lang=None,
tgt_eos=None,
tgt_lang=None,
):
if self.args.encoder_langtok is None and not self.args.decoder_langtok:
return lang_pair_dataset
new_src_eos = None
if (
self.args.encoder_langtok is not None
and src_eos is not None
and src_lang is not None
and tgt_lang is not None
):
new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
else:
src_eos = None
new_tgt_bos = None
if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
new_tgt_bos = self.get_decoder_langtok(tgt_lang)
else:
tgt_eos = None
return TransformEosLangPairDataset(
lang_pair_dataset,
src_eos=src_eos,
new_src_eos=new_src_eos,
tgt_bos=tgt_eos,
new_tgt_bos=new_tgt_bos,
)
def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split("-")
langpair_dataset = load_langpair_dataset(
data_path,
split,
src,
self.dicts[src],
tgt,
self.dicts[tgt],
combine=True,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
)
return self.alter_dataset_langtok(
langpair_dataset,
src_eos=self.dicts[src].eos(),
src_lang=src,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
)
self.datasets[split] = RoundRobinZipDatasets(
OrderedDict(
[
(lang_pair, language_pair_dataset(lang_pair))
for lang_pair in self.lang_pairs
]
),
eval_key=None
if self.training
else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the multilingual_translation task is not supported"
)
lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
return RoundRobinZipDatasets(
OrderedDict(
[
(
lang_pair,
self.alter_dataset_langtok(
LanguagePairDataset(
src_tokens, src_lengths, self.source_dictionary
),
src_eos=self.source_dictionary.eos(),
src_lang=self.args.source_lang,
tgt_eos=self.target_dictionary.eos(),
tgt_lang=self.args.target_lang,
),
)
]
),
eval_key=lang_pair,
)
def build_model(self, args, from_checkpoint=False):
def check_args():
messages = []
if (
len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs))
!= 0
):
messages.append(
"--lang-pairs should include all the language pairs {}.".format(
args.lang_pairs
)
)
if self.args.encoder_langtok != args.encoder_langtok:
messages.append(
"--encoder-langtok should be {}.".format(args.encoder_langtok)
)
if self.args.decoder_langtok != args.decoder_langtok:
messages.append(
"--decoder-langtok should {} be set.".format(
"" if args.decoder_langtok else "not"
)
)
if len(messages) > 0:
raise ValueError(" ".join(messages))
# Update args -> the fact that the constructor here
# changes the args object doesn't mean you get the same one here
self.update_args(args)
# Check if task args are consistant with model args
check_args()
from fairseq import models
model = models.build_model(args, self, from_checkpoint)
if not isinstance(model, FairseqMultiModel):
raise ValueError(
"MultilingualTranslationTask requires a FairseqMultiModel architecture"
)
return model
def _per_lang_pair_train_loss(
self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
):
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
if ignore_grad:
loss *= 0
optimizer.backward(loss)
return loss, sample_size, logging_output
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
from collections import defaultdict
agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
curr_lang_pairs = [
lang_pair
for lang_pair in self.model_lang_pairs
if sample[lang_pair] is not None and len(sample[lang_pair]) != 0
]
for idx, lang_pair in enumerate(curr_lang_pairs):
def maybe_no_sync():
if (
self.args.distributed_world_size > 1
and hasattr(model, "no_sync")
and idx < len(curr_lang_pairs) - 1
):
return model.no_sync()
else:
return contextlib.ExitStack() # dummy contextmanager
with maybe_no_sync():
loss, sample_size, logging_output = self._per_lang_pair_train_loss(
lang_pair,
model,
update_num,
criterion,
sample,
optimizer,
ignore_grad,
)
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
for k in logging_output:
agg_logging_output[k] += logging_output[k]
agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
return criterion(model.models[lang_pair], sample[lang_pair])
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
from collections import defaultdict
agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
for lang_pair in self.eval_lang_pairs:
if (
lang_pair not in sample
or sample[lang_pair] is None
or len(sample[lang_pair]) == 0
):
continue
loss, sample_size, logging_output = self._per_lang_pair_valid_loss(
lang_pair, model, criterion, sample
)
agg_loss += loss.data.item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
for k in logging_output:
agg_logging_output[k] += logging_output[k]
agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
if self.args.decoder_langtok:
bos_token = _lang_token_index(
self.target_dictionary, self.args.target_lang
)
else:
bos_token = self.target_dictionary.eos()
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
bos_token=bos_token,
)
def reduce_metrics(self, logging_outputs, criterion):
with metrics.aggregate():
# pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task
super().reduce_metrics(logging_outputs, criterion)
for k in ["sample_size", "nsentences", "ntokens"]:
metrics.log_scalar(k, sum(l[k] for l in logging_outputs))
@property
def source_dictionary(self):
if self.training:
return next(iter(self.dicts.values()))
else:
return self.dicts[self.args.source_lang]
@property
def target_dictionary(self):
if self.training:
return next(iter(self.dicts.values()))
else:
return self.dicts[self.args.target_lang]
def max_positions(self):
"""Return the max sentence length allowed by the task."""
if len(self.datasets.values()) == 0:
return {
"%s-%s"
% (self.args.source_lang, self.args.target_lang): (
self.args.max_source_positions,
self.args.max_target_positions,
)
}
return OrderedDict(
[
(key, (self.args.max_source_positions, self.args.max_target_positions))
for split in self.datasets.keys()
for key in self.datasets[split].datasets.keys()
]
)

View File

@@ -0,0 +1,477 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import torch
import json
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any
from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
from . import register_task
from .. import utils
from ..logging import metrics
logger = logging.getLogger(__name__)
class LabelEncoder(object):
def __init__(self, dictionary):
self.dictionary = dictionary
def __call__(self, label):
return self.dictionary.encode_line(
label, append_eos=False, add_if_not_exist=False
)
def label_len_fn(label):
return len(label.split(" "))
@dataclass
class NLUFinetuningConfig(AudioPretrainingConfig):
# Options for reporting WER metrics during validation. Only applicable to
# Seq2Seq models during fine-tuning
eval_wer: bool = field(
default=False, metadata={"help": "compute WER for Seq2Seq models"}
)
eval_wer_parse: bool = field(
default=False, metadata={"help": "compute WER for Seq2Seq models"}
)
eval_wer_config: GenerationConfig = field(
default_factory=lambda: GenerationConfig(),
metadata={"help": "beam search config for evaluating wer during training"},
)
eval_wer_tokenizer: Any = field(
default=None,
metadata={"help": "tokenizer config for evaluating wer during training"},
)
eval_wer_post_process: str = field(
default="letter",
metadata={
"help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
},
)
eval_bleu: bool = field(
default=False, metadata={"help": "evaluation with BLEU scores"}
)
eval_bleu_detok: Optional[str] = field(
default=None,
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); "
"required if using --eval-bleu; use 'space' to disable "
"detokenization; see fairseq.data.encoders for other options"
},
)
eval_bleu_detok_args: str = field(
default="{}", metadata={"help": "args for building the tokenizer, if needed"}
)
eval_tokenized_bleu: bool = field(
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
)
eval_bleu_remove_bpe: Optional[str] = field(
default=None, metadata={"help": "remove BPE before computing BLEU"}
)
eval_bleu_args: str = field(
default="{}",
metadata={
"help": "generation args for BLUE scoring, e.g., "
'\'{"beam": 4, "lenpen": 0.6}\''
},
)
eval_bleu_print_samples: bool = field(
default=False, metadata={"help": "print sample generations during validation"}
)
autoregressive: bool = field(
default=False,
metadata={
"help": "required for autoregressive decoders (like seq2seq models); "
"adds 'prev_output_tokens' to input and appends eos to target"
},
)
@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig)
class NLUFinetuningTask(AudioPretrainingTask):
""" """
cfg: NLUFinetuningConfig
def __init__(
self,
cfg: NLUFinetuningConfig,
):
super().__init__(cfg)
self.blank_symbol = "<s>"
self.state.add_factory("target_dictionary", self.load_target_dictionary)
def load_target_dictionary(self):
if self.cfg.labels:
dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
return Dictionary.load(dict_path)
return None
def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
assert task_cfg.labels is not None
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level,
)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.state.target_dictionary
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.cfg.eval_wer_parse and self.cfg.autoregressive:
metrics = self._inference_with_wer_parse(
self.sequence_generator, sample, model
)
logging_output["_num_char_errors"] = metrics["num_char_errors"]
logging_output["_num_chars"] = metrics["num_chars"]
logging_output["_num_word_errors"] = metrics["num_word_errors"]
logging_output["_num_words"] = metrics["num_words"]
logging_output["_num_em_errors"] = metrics["num_em_errors"]
logging_output["_num_ems"] = metrics["num_ems"]
logging_output["_num_tree_errors"] = metrics["num_tree_errors"]
logging_output["_num_trees"] = metrics["num_trees"]
if self.cfg.eval_wer and self.cfg.autoregressive:
metrics = self._inference_with_wer(self.sequence_generator, sample, model)
logging_output["_num_char_errors"] = metrics["num_char_errors"]
logging_output["_num_chars"] = metrics["num_chars"]
logging_output["_num_word_errors"] = metrics["num_word_errors"]
logging_output["_num_words"] = metrics["num_words"]
if self.cfg.eval_bleu and self.cfg.autoregressive:
metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output["_bleu_sys_len"] = metrics.sys_len
logging_output["_bleu_ref_len"] = metrics.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(metrics.counts) == 4
for i in range(4):
logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
return loss, sample_size, logging_output
def build_model(self, model_cfg: FairseqDataclass):
model = super().build_model(model_cfg)
if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive:
self.sequence_generator = self.build_generator(
[model],
self.cfg.eval_wer_config,
)
if self.cfg.eval_wer_tokenizer:
self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
else:
self.tokenizer = None
if self.cfg.eval_bleu and self.cfg.autoregressive:
assert self.cfg.eval_bleu_detok is not None, (
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(self.cfg.eval_bleu_args)
gen_args = Namespace(**gen_args)
self.sequence_generator = self.build_generator([model], gen_args)
return model
def _inference_with_wer_parse(self, generator, sample, model):
import editdistance
def decode(toks):
s = self.target_dictionary.string(
toks.int().cpu(),
self.cfg.eval_wer_post_process,
escape_unk=True,
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
def decode_to_list(toks):
def token_string(i):
if i == self.target_dictionary.unk():
return self.target_dictionary.unk_string(False)
else:
return self.target_dictionary[i]
return [token_string(i) for i in toks]
def is_ont_token(token):
return "[" in token or "]" in token
def post_process(l):
o = []
for w in l:
if w == self.target_dictionary.eos_word or w == "|":
continue
if w == "_":
o.append(" ")
else:
o.append(w)
if is_ont_token(w):
o.append(" ")
return o
num_word_errors, num_char_errors = 0, 0
num_chars, num_words = 0, 0
num_em_errors, num_ems = 0, 0
num_tree_errors, num_trees = 0, 0
gen_out = self.inference_step(generator, [model], sample, None)
for i in range(len(gen_out)):
hyp_tokens = gen_out[i][0]["tokens"]
# hyp = decode(hyp_tokens)
ref_tokens = utils.strip_pad(
sample["target"][i], self.target_dictionary.pad()
)
# ref = decode(ref_tokens)
hyp_list = decode_to_list(hyp_tokens)
ref_list = decode_to_list(ref_tokens)
hyp_list = post_process(hyp_list)
ref_list = post_process(ref_list)
hyp = "".join(hyp_list).strip()
ref = "".join(ref_list).strip()
num_chars += len(ref)
num_char_errors += editdistance.eval(hyp, ref)
hyp_words = hyp.split()
ref_words = ref.split()
hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)]
ref_tree = [word for word in ref_list if ("[" in word or "]" in word)]
# num_word_errors += editdistance.eval(hyp_words, ref_words)
hyp_before = decode(hyp_tokens).split()
ref_before = decode(ref_tokens).split()
num_word_errors += editdistance.eval(hyp_before, ref_before)
num_words += len(ref_before)
if hyp != ref:
num_em_errors += 1
if hyp_tree != ref_tree:
num_tree_errors += 1
num_ems += 1
num_trees += 1
return {
"num_char_errors": num_char_errors,
"num_chars": num_chars,
"num_word_errors": num_word_errors,
"num_words": num_words,
"num_ems": num_ems,
"num_em_errors": num_em_errors,
"num_trees": num_trees,
"num_tree_errors": num_tree_errors,
}
def _inference_with_wer(self, generator, sample, model):
import editdistance
def decode(toks):
s = self.target_dictionary.string(
toks.int().cpu(),
self.cfg.eval_wer_post_process,
escape_unk=True,
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
num_word_errors, num_char_errors = 0, 0
num_chars, num_words = 0, 0
gen_out = self.inference_step(generator, [model], sample, None)
for i in range(len(gen_out)):
hyp = decode(gen_out[i][0]["tokens"])
ref = decode(
utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
)
num_char_errors += editdistance.eval(hyp, ref)
num_chars += len(ref)
hyp_words = hyp.split()
ref_words = ref.split()
num_word_errors += editdistance.eval(hyp_words, ref_words)
num_words += len(ref_words)
return {
"num_char_errors": num_char_errors,
"num_chars": num_chars,
"num_word_errors": num_word_errors,
"num_words": num_words,
}
def _inference_with_bleu(self, generator, sample, model):
import sacrebleu
def decode(toks, is_ref):
s = self.target_dictionary.string(
toks.int().cpu(),
self.cfg.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
refs.append(
decode(
utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
is_ref=True, # don't count <unk> as matches to the hypo
)
)
if self.cfg.eval_bleu_print_samples:
logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
logger.info("T-{} {}".format(sample["id"][0], refs[0]))
eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.cfg.eval_wer or self.cfg.eval_wer_parse:
zero = torch.scalar_tensor(0.0)
num_char_errors = sum(
log.get("_num_char_errors", zero) for log in logging_outputs
)
num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
num_word_errors = sum(
log.get("_num_word_errors", zero) for log in logging_outputs
)
num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
metrics.log_scalar("_num_char_errors", num_char_errors)
metrics.log_scalar("_num_chars", num_chars)
metrics.log_scalar("_num_word_errors", num_word_errors)
metrics.log_scalar("_num_words", num_words)
if num_chars > 0:
metrics.log_derived(
"uer",
lambda meters: meters["_num_char_errors"].sum
* 100.0
/ meters["_num_chars"].sum
if meters["_num_chars"].sum > 0
else float("nan"),
)
if num_words > 0:
metrics.log_derived(
"wer",
lambda meters: meters["_num_word_errors"].sum
* 100.0
/ meters["_num_words"].sum
if meters["_num_words"].sum > 0
else float("nan"),
)
if self.cfg.eval_wer_parse:
num_em_errors = sum(
log.get("_num_em_errors", zero) for log in logging_outputs
)
num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs)
metrics.log_scalar("_num_em_errors", num_em_errors)
metrics.log_scalar("_num_ems", num_ems)
num_tree_errors = sum(
log.get("_num_tree_errors", zero) for log in logging_outputs
)
num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs)
metrics.log_scalar("_num_tree_errors", num_tree_errors)
metrics.log_scalar("_num_trees", num_trees)
if num_ems > 0:
metrics.log_derived(
"em_error",
lambda meters: meters["_num_em_errors"].sum
* 100.0
/ meters["_num_ems"].sum
if meters["_num_ems"].sum > 0
else float("nan"),
)
if num_trees > 0:
metrics.log_derived(
"tree_error",
lambda meters: meters["_num_tree_errors"].sum
* 100.0
/ meters["_num_trees"].sum
if meters["_num_trees"].sum > 0
else float("nan"),
)
if self.cfg.eval_bleu:
len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
count_keys = [f"_bleu_counts_{i}" for i in range(4)]
total_keys = [f"_bleu_totals_{i}" for i in range(4)]
for k in len_keys + count_keys + total_keys:
metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))
import sacrebleu
metrics.log_derived(
"bleu",
lambda meters: sacrebleu.compute_bleu(
correct=[meters[k].sum for k in count_keys],
total=[meters[k].sum for k in total_keys],
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
smooth_method="exp",
).score,
)

View File

@@ -0,0 +1,682 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import json
import logging
import math
import os
from argparse import Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple
from argparse import ArgumentError
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairseq
from fairseq import metrics, options, utils
from fairseq.data import (
FairseqDataset,
LanguagePairDataset,
NoisingDataset,
PrependTokenDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
data_utils,
encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
logger = logging.getLogger(__name__)
class PiecewiseLinearFn:
"""Piecewise linear function. Can be configured with a string."""
def __init__(self, pieces: Sequence[Tuple[int, float]]):
assert pieces == sorted(
pieces
), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"
self.pieces = pieces
def __call__(self, x: int) -> float:
for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
x_b, y_b = self.pieces[i + 1]
if x_a <= x <= x_b:
return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)
return self.pieces[-1][1]
@staticmethod
def from_string(configuration: str) -> "PiecewiseLinearFn":
"""
Parse the configuration of lambda coefficient (for scheduling).
x = "3" # lambda will be a constant equal to x
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
# to 0 during the first 1000 iterations
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
if isinstance(configuration, float):
return PiecewiseLinearFn([(0, configuration)])
try:
parts = configuration.split(",")
if len(parts) == 1:
v = float(configuration)
return PiecewiseLinearFn([(0, v)])
split = [s.split(":") for s in parts]
pieces = [(int(t), float(v)) for t, v in split]
return PiecewiseLinearFn(pieces)
except Exception:
raise ValueError(
f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
)
@staticmethod
def one() -> "PiecewiseLinearFn":
return PiecewiseLinearFn([(0, 1.0)])
@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
# Generic translation args
parser.add_argument('data', help='colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner; \
however, valid and test data are always in the first directory to \
avoid the need for repeating them in all directories')
parser.add_argument('--mono-langs', metavar='MONO_LANGS',
help='monolingual languages for training')
parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
help='language pairs for validation')
parser.add_argument('--load-alignments', action='store_true',
help='load the binarized alignments')
parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left')
parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset')
try:
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
except ArgumentError:
# this might have already been defined. Once we transition this to hydra it should be fine to add it here.
pass
parser.add_argument('--truncate-source', action='store_true', default=False,
help='truncate source to max-source-positions')
parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
help='if >0, then bucket source and target lengths into N '
'buckets and pad accordingly; this is useful on TPUs '
'to minimize the number of compilations')
# Denoising args
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
help='maximum word shuffle distance for denoising autoencoding data generation')
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
help='word dropout probability for denoising autoencoding data generation')
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
help='word blanking probability for denoising autoencoding data generation')
# Backtranslation args
parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
help='back-translation weight')
parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
help='denoising auto-encoder weight')
# Evaluation args
parser.add_argument('--generate-one-by-one', action='store_true',
help='generate one sentence at a time for backtranslation')
parser.add_argument('--eval-bleu', action='store_true',
help='evaluation with BLEU scores')
parser.add_argument('--eval-bleu-detok', type=str, default="space",
help='detokenize before computing BLEU (e.g., "moses"); '
'required if using --eval-bleu; use "space" to '
'disable detokenization; see fairseq.data.encoders '
'for other options')
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
help='args for building the tokenizer, if needed')
parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
help='compute tokenized BLEU instead of sacrebleu')
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE before computing BLEU')
parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
help='generation args for BLUE scoring, '
'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
parser.add_argument('--eval-bleu-print-samples', action='store_true',
help='print sample generations during validation')
# fmt: on
def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
super().__init__(args, common_dict, common_dict)
self.common_dict = common_dict
self.mono_langs = mono_langs
self.valid_lang_pairs = valid_lang_pairs
self.SHOW_SAMPLES_INTERVAL = 1000
# Start by showing samples
self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
self.SHOW_SAMPLES_NUMBER = 5
self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)
self.args = args
self.data = utils.split_paths(self.args.data)
if len(self.data) == 1:
shards = list(Path(self.data[0]).glob("shard*"))
if len(shards) > 0:
# keep this as strings, since it can also be a manifold path
old_data = self.data
self.data = [str(shard) for shard in shards]
logging.warning(f"Expanded data directory {old_data} to {self.data}")
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
paths = utils.split_paths(args.data)
assert len(paths) > 0
assert args.mono_langs is not None
mono_langs = args.mono_langs.split(",")
valid_lang_pairs = args.valid_lang_pairs.split(",")
# load dictionary
dict_path = os.path.join(paths[0], "dict.txt")
common_dict = cls.load_dictionary(dict_path)
return cls(args, common_dict, mono_langs, valid_lang_pairs)
def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if split == "train":
data_path = self.data[(epoch - 1) % len(self.data)]
dataset = self.load_train_dataset(data_path)
else:
# valid/test should always be the same.
dataset = self.load_translation_dataset(split, self.data[0])
self.datasets[split] = dataset
return dataset
def load_train_dataset(self, data_path: str) -> FairseqDataset:
"""The training dataset is made of backtranslation dataset and denoising dataset."""
data = []
for lang in self.mono_langs:
train_path = os.path.join(data_path, lang, "train")
# TODO: could we do the BT using denoise sample ?
# this would half the data loading work
data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
data.append(
(f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
)
return RoundRobinZipDatasets(OrderedDict(data))
def _langpair_dataset(
self, src: FairseqDataset, tgt: FairseqDataset
) -> LanguagePairDataset:
return LanguagePairDataset(
src,
src.sizes,
self.dictionary,
tgt=tgt,
tgt_sizes=tgt.sizes,
tgt_dict=self.dictionary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
# TODO: should we shuffle ? we are already sorting batch by sizes so ?
# shuffle=True,
)
def _prepend_lang_bos_to_target(
self, dataset: LanguagePairDataset, lang: str
) -> LanguagePairDataset:
bos = _lang_token_index(self.dictionary, lang)
return TransformEosLangPairDataset(
dataset,
src_eos=self.dictionary.eos(),
new_src_eos=self.dictionary.eos(),
tgt_bos=self.dictionary.eos(),
new_tgt_bos=bos,
)
def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
"""The BT dataset is generated with (tgt, tgt) pairs.
The actual translation to a (generated_src, tgt) pair
is done on the fly during training.
"""
mono_dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
assert mono_dataset is not None, f"No dataset found for {lang}"
mono_dataset_src = PrependTokenDataset(
mono_dataset, _lang_token_index(self.dictionary, lang)
)
mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
logger.info(
f"mono_lang = {lang} "
f"lang token index = {_lang_token_index(self.dictionary, lang)} "
f"lang token = {_lang_token(lang)}"
)
mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
return mono_dataset_bt
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
"""Classic denoising dataset"""
dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
noisy_dataset = NoisingDataset(
dataset,
self.dictionary,
seed=1,
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
word_dropout_prob=self.args.word_dropout_prob,
word_blanking_prob=self.args.word_blanking_prob,
)
noisy_dataset = PrependTokenDataset(
noisy_dataset, _lang_token_index(self.dictionary, lang)
)
clean_dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
return denoising_dataset
def load_translation_dataset(
self, split: str, data_path: str, combine: bool = False
):
# only judging with one language pair for the moment,
# since ConcatDataset doesn't work as expected
assert len(self.valid_lang_pairs) == 1, "For now..."
valid_lang_pair = self.valid_lang_pairs[0]
src, tgt = valid_lang_pair.split("-")
# use the same function than TranslationTask
src_tgt_dt = load_langpair_dataset(
data_path,
split,
src,
self.common_dict,
tgt,
self.common_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
truncate_source=self.args.truncate_source,
num_buckets=self.args.num_batch_buckets,
shuffle=(split != "test"),
prepend_bos_src=_lang_token_index(self.dictionary, src),
)
src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
src_tgt_eos_dt.args = self.args
return src_tgt_eos_dt
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
raise NotImplementedError
def build_model(self, args, from_checkpoint=False):
# torch.autograd.set_detect_anomaly(True)
model = super().build_model(args, from_checkpoint)
add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)
self.sequence_generators = {}
for mono_lang in self.mono_langs:
self.sequence_generators[mono_lang] = SequenceGenerator(
[model],
tgt_dict=self.dictionary,
beam_size=1,
max_len_a=1.3,
max_len_b=5,
min_len=5,
# keep 1 to be able to prepend bos
max_len=model.max_decoder_positions() - 1,
)
if getattr(args, "eval_bleu", False):
assert getattr(args, "eval_bleu_detok", None) is not None, (
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
self.tokenizer = encoders.build_tokenizer(
Namespace(
tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
)
)
gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
self.bleu_sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
return model
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.common_dict
def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
self._show_samples_ctr += 1
if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
return
self._show_samples_ctr = 0
ln = smp["net_input"]["src_tokens"].shape[0]
logger.info(
f"(r:{self.args.distributed_rank}) : "
f"{other_lang} ---> {mono_lang} "
f"({other_lang} was generated by back-translation.) {ln} samples"
)
for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
src_tokens = smp["net_input"]["src_tokens"][i]
tgt_tokens = smp["target"][i]
src_str = self.dictionary.string(src_tokens, "sentencepiece")
tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
logger.info(
f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
f"\t\t[{mono_lang} original ] {tgt_str}\n"
f"\t\t[ src tokens] {src_tokens}\n"
)
def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
"""
* WARNING: smp is modified in place.
* At the start of this function, `smp` has the same input and target:
|--------------------------------------------------------|
| smp['net_input']['src_tokens'] | smp['target'] |
| (from data) __en__ hello world | __en__ hello world |
|--------------------------------------------------------|
* We call generator.generate(smp, bos_token = token("ro")),
and copy the result as input
* At the end, `smp` has the translation to other language.
|--------------------------------------------------------|
| smp['net_input']['src_tokens'] | smp['target'] |
| (generated) __ro__ salut lume | __en__ hello world |
|--------------------------------------------------------|
"""
bos_token = _lang_token_index(self.dictionary, other_lang)
generated = self.sequence_generators[orig_lang].generate(
models=[], sample=smp, bos_token=bos_token
)
max_lngth = max([gn[0]["tokens"].size(0) for gn in generated])
net_input = smp["net_input"]
n_src_tokens = torch.empty(
size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype
)
n_src_lengths = torch.empty(
len(generated), dtype=net_input["src_lengths"].dtype
)
for i, gn in enumerate(generated):
tokens = gn[0]["tokens"]
tokens_size = tokens.size(0)
padding_needed = max_lngth - tokens_size
tokens = torch.cat([tokens.new([bos_token]), tokens])
tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
n_src_tokens[i] = tokens
n_src_lengths[i] = tokens_size + 1
device = net_input["src_tokens"].device
# This seems to be important
del net_input["src_tokens"]
del net_input["src_lengths"]
net_input["src_tokens"] = n_src_tokens.to(device)
net_input["src_lengths"] = n_src_lengths.to(device)
def generate(self, smp, model):
model.eval()
orig_lang = (
self.dictionary[smp["net_input"]["src_tokens"][0][0]]
.replace(" ", "")
.replace("_", "")
)
bos_token = smp["net_input"]["prev_output_tokens"][0][0]
with torch.no_grad():
generated = self.sequence_generators[orig_lang].generate(
models=[model], sample=smp, bos_token=bos_token
)
return generated
def get_other_lang(self, lang):
# TODO: allow more complex mapping
if lang != self.mono_langs[0]:
return self.mono_langs[0]
if len(self.mono_langs) == 2:
return self.mono_langs[1]
return self.mono_langs[np.random.randint(1, len(self.mono_langs))]
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
model.set_num_updates(update_num)
agg_loss, agg_sample_size = 0.0, 0.0
agg_logging_output: Dict[str, float] = defaultdict(float)
dataset_keys = self.datasets["train"].datasets.keys()
weights = {
"BT": self.lambda_bt(update_num),
"DENOISE": self.lambda_dae(update_num),
}
log_keys = {"BT": "bt_", "DENOISE": "dae_"}
for dataset_key in dataset_keys:
smp = sample[dataset_key]
mono_lang, task_subtype = dataset_key.split("-")
if weights[task_subtype] == 0:
continue
if task_subtype == "BT":
with torch.autograd.profiler.record_function("backtranslation"):
model.eval()
# TODO: Could we translate to several language at once ?
# this would allow to share encoder_out and maximize GPU usage.
other_lang = self.get_other_lang(mono_lang)
self.backtranslate_sample(smp, mono_lang, other_lang)
self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
model.train()
# Like in FairseqTask.train_step
with torch.autograd.profiler.record_function("forward"):
loss, sample_size, logging_output = criterion(model, smp)
loss *= weights[task_subtype]
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
optimizer.backward(loss)
agg_loss += loss.item()
agg_sample_size += sample_size
for k in logging_output:
agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
agg_logging_output[k] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
def get_bos_token_from_sample(self, sample):
net_input = sample["net_input"]
source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
target_lang_token_id = _lang_token_index(
self.dictionary, self.get_other_lang(source_lang_token)
)
return target_lang_token_id
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
if bt_sample_size:
bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
bt_loss_sum *= 1 / bt_sample_size / math.log(2)
metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)
bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
metrics.log_derived(
"bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
)
dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
if dae_sample_size:
dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
dae_loss_sum *= 1 / dae_sample_size / math.log(2)
metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)
dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
metrics.log_derived(
"dae_ppl",
lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
)
@torch.no_grad()
def extend_embedding(
emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
old_emb_data = emb.weight.data
(old_vocab_size, dim) = old_emb_data.shape
assert new_vocab_size >= old_vocab_size
if new_vocab_size > old_vocab_size:
emb.weight.data = torch.zeros((new_vocab_size, dim))
emb.weight.data[:old_vocab_size, :] = old_emb_data
# initialize new embeddings
emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
if hasattr(emb, "num_embeddings"):
emb.num_embeddings = new_vocab_size
if hasattr(emb, "out_features"):
emb.out_features = new_vocab_size
if getattr(emb, "bias", None) is None:
return
# Fix the bias.
# Bias shape can be different from the previous vocab size
# if the weight matrix was shared and alread extended but not the bias.
(old_vocab_size,) = emb.bias.shape
assert new_vocab_size >= old_vocab_size
if new_vocab_size > old_vocab_size:
old_bias = emb.bias.data
new_bias = torch.zeros(
(new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
)
new_bias[:old_vocab_size] = old_bias
emb.bias.data = new_bias
def add_secial_tokens_to_dict_and_model(
dictionary: "fairseq.data.Dictionary",
model: nn.Module,
mono_langs: Sequence[str],
) -> None:
embs = model.encoder.embed_tokens
vocab_size, embedding_dim = embs.weight.shape
# The model may or may not have a '<mask>' embedding yet
assert (
len(dictionary) <= vocab_size <= len(dictionary) + 1
), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
# TODO: we should reuse the pretrained model dict which already has <mask>
dictionary.add_symbol("<mask>")
for lang in mono_langs:
lang_token = _lang_token(lang)
dictionary.add_symbol(lang_token)
logger.info(
f"dictionary: {len(dictionary)} -> {vocab_size} tokens "
f"after adding {len(mono_langs)} lang tokens."
)
if len(dictionary) <= vocab_size:
return
extend_embedding(embs, len(dictionary), dictionary.bos())
dec_embs = model.decoder.embed_tokens
extend_embedding(dec_embs, len(dictionary), dictionary.bos())
lm_head = model.decoder.output_projection
extend_embedding(lm_head, len(dictionary), dictionary.bos())
assert lm_head.weight.shape == (len(dictionary), embedding_dim)
def _lang_token(lang: str) -> str:
return f"__{lang}__"
def _lang_token_index(dictionary, lang: str) -> int:
return dictionary.index(_lang_token(lang))
@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
def checksum(model: nn.Module) -> float:
return sum(p.sum().item() for p in model.parameters())
initial_checksum = checksum(model)
yield model
final_checksum = checksum(model)
logger.info(
f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
)
assert initial_checksum != final_checksum, "Model hasn't changed !"

View File

@@ -0,0 +1,485 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from collections import OrderedDict
from fairseq import utils
from fairseq.data import (
BacktranslationDataset,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
data_utils,
indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator
from . import register_task
from .multilingual_translation import MultilingualTranslationTask
logger = logging.getLogger(__name__)
def _get_bt_dataset_key(lang_pair):
return "bt:" + lang_pair
def _get_denoising_dataset_key(lang_pair):
return "denoising:" + lang_pair
# ported from UnsupervisedMT
def parse_lambda_config(x):
"""
Parse the configuration of lambda coefficient (for scheduling).
x = "3" # lambda will be a constant equal to x
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
# to 0 during the first 1000 iterations
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
split = x.split(",")
if len(split) == 1:
return float(x), None
else:
split = [s.split(os.pathsep) for s in split]
assert all(len(s) == 2 for s in split)
assert all(k.isdigit() for k, _ in split)
assert all(
int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
)
return float(split[0][1]), [(int(k), float(v)) for k, v in split]
@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered
according to the `--lang-pairs` argument.
The training loop is roughly:
for i in range(len(epoch)):
for lang_pair in args.lang_pairs:
batch = next_batch_for_lang_pair(lang_pair)
loss = criterion(model_for_lang_pair(lang_pair), batch)
loss.backward()
optimizer.step()
In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
(e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
implements the `FairseqMultiModel` interface.
During inference it is required to specify a single `--source-lang` and
`--target-lang`, instead of `--lang-pairs`.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
MultilingualTranslationTask.add_args(parser)
parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (parallel data). '
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
help='Cross-entropy reconstruction coefficient (denoising autoencoding)'
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data)'
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
help='beam size used in beam search of online back-translation')
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
help='maximum word shuffle distance for denoising autoencoding data generation')
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
help='word dropout probability for denoising autoencoding data generation')
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
help='word blanking probability for denoising autoencoding data generation')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
args.lambda_parallel_config
)
self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
args.lambda_otf_bt_config
)
self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
args.lambda_denoising_config
)
if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
denoising_lang_pairs = [
"%s-%s" % (tgt, tgt)
for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
]
self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
self.backtranslate_datasets = {}
self.backtranslators = {}
@classmethod
def setup_task(cls, args, **kwargs):
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
return cls(args, dicts, training)
def load_dataset(self, split, epoch=1, **kwargs):
"""Load a dataset split."""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
def split_exists(split, src, tgt, lang):
if src is not None:
filename = os.path.join(
data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
)
else:
filename = os.path.join(
data_path, "{}.{}-None.{}".format(split, src, tgt)
)
return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)
def load_indexed_dataset(path, dictionary):
return data_utils.load_indexed_dataset(
path, dictionary, self.args.dataset_impl
)
# load parallel datasets
src_datasets, tgt_datasets = {}, {}
if (
self.lambda_parallel > 0.0
or self.lambda_parallel_steps is not None
or not split.startswith("train")
):
for lang_pair in self.lang_pairs:
src, tgt = lang_pair.split("-")
if split_exists(split, src, tgt, src):
prefix = os.path.join(
data_path, "{}.{}-{}.".format(split, src, tgt)
)
elif split_exists(split, tgt, src, src):
prefix = os.path.join(
data_path, "{}.{}-{}.".format(split, tgt, src)
)
else:
continue
src_datasets[lang_pair] = load_indexed_dataset(
prefix + src, self.dicts[src]
)
tgt_datasets[lang_pair] = load_indexed_dataset(
prefix + tgt, self.dicts[tgt]
)
logger.info(
"parallel-{} {} {} examples".format(
data_path, split, len(src_datasets[lang_pair])
)
)
if len(src_datasets) == 0:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
# back translation datasets
backtranslate_datasets = {}
if (
self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
) and split.startswith("train"):
for lang_pair in self.lang_pairs:
src, tgt = lang_pair.split("-")
if not split_exists(split, tgt, None, tgt):
raise FileNotFoundError(
"Dataset not found: backtranslation {} ({})".format(
split, data_path
)
)
filename = os.path.join(
data_path, "{}.{}-None.{}".format(split, tgt, tgt)
)
dataset = load_indexed_dataset(filename, self.dicts[tgt])
lang_pair_dataset_tgt = LanguagePairDataset(
dataset,
dataset.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
lang_pair_dataset = LanguagePairDataset(
dataset,
dataset.sizes,
src_dict=self.dicts[src],
tgt=dataset,
tgt_sizes=dataset.sizes,
tgt_dict=self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
backtranslate_datasets[lang_pair] = BacktranslationDataset(
tgt_dataset=self.alter_dataset_langtok(
lang_pair_dataset_tgt,
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_lang=src,
),
backtranslation_fn=self.backtranslators[lang_pair],
src_dict=self.dicts[src],
tgt_dict=self.dicts[tgt],
output_collater=self.alter_dataset_langtok(
lang_pair_dataset=lang_pair_dataset,
src_eos=self.dicts[src].eos(),
src_lang=src,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
).collater,
)
logger.info(
"backtranslate-{}: {} {} {} examples".format(
tgt,
data_path,
split,
len(backtranslate_datasets[lang_pair]),
)
)
self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
lang_pair
]
# denoising autoencoder
noising_datasets = {}
if (
self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
) and split.startswith("train"):
for lang_pair in self.lang_pairs:
_, tgt = lang_pair.split("-")
if not split_exists(split, tgt, None, tgt):
continue
filename = os.path.join(
data_path, "{}.{}-None.{}".format(split, tgt, tgt)
)
tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
noising_dataset = NoisingDataset(
tgt_dataset1,
self.dicts[tgt],
seed=1,
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
word_dropout_prob=self.args.word_dropout_prob,
word_blanking_prob=self.args.word_blanking_prob,
)
noising_datasets[lang_pair] = self.alter_dataset_langtok(
LanguagePairDataset(
noising_dataset,
tgt_dataset1.sizes,
self.dicts[tgt],
tgt_dataset2,
tgt_dataset2.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
),
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
)
logger.info(
"denoising-{}: {} {} {} examples".format(
tgt,
data_path,
split,
len(noising_datasets[lang_pair]),
)
)
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split("-")
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return self.alter_dataset_langtok(
LanguagePairDataset(
src_dataset,
src_dataset.sizes,
self.dicts[src],
tgt_dataset,
tgt_dataset.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
),
self.dicts[src].eos(),
src,
self.dicts[tgt].eos(),
tgt,
)
self.datasets[split] = RoundRobinZipDatasets(
OrderedDict(
[
(lang_pair, language_pair_dataset(lang_pair))
for lang_pair in src_datasets.keys()
]
+ [
(_get_bt_dataset_key(lang_pair), dataset)
for lang_pair, dataset in backtranslate_datasets.items()
]
+ [
(_get_denoising_dataset_key(lang_pair), dataset)
for lang_pair, dataset in noising_datasets.items()
]
),
eval_key=None
if self.training
else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_model(self, args, from_checkpoint=False):
from fairseq import models
model = models.build_model(args, self, from_checkpoint)
if not isinstance(model, FairseqMultiModel):
raise ValueError(
"SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
)
# create SequenceGenerator for each model that has backtranslation dependency on it
self.sequence_generators = {}
if (
self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
) and self.training:
for lang_pair in self.lang_pairs:
src, tgt = lang_pair.split("-")
key = "{}-{}".format(tgt, src)
self.sequence_generators[key] = SequenceGenerator(
[model.models[key]],
tgt_dict=self.dicts[src],
beam_size=args.bt_beam_size,
max_len_a=args.bt_max_len_a,
max_len_b=args.bt_max_len_b,
)
decoder_lang_tok_idx = self.get_decoder_langtok(src)
def backtranslate_fn(
sample,
model=model.models[key],
bos_token=decoder_lang_tok_idx,
sequence_generator=self.sequence_generators[key],
):
return sequence_generator.generate(
[model],
sample,
bos_token=bos_token,
)
self.backtranslators[lang_pair] = backtranslate_fn
return model
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
if update_num > 0:
self.update_step(update_num)
agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {}
def forward_backward(model, samples, logging_output_key, weight):
nonlocal agg_loss, agg_sample_size, agg_logging_output
if samples is None or len(samples) == 0:
return
loss, sample_size, logging_output = criterion(model, samples)
if ignore_grad:
loss *= 0
else:
loss *= weight
optimizer.backward(loss)
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
for k in logging_output:
agg_logging_output[k] += logging_output[k]
agg_logging_output[logging_output_key] += logging_output[k]
if self.lambda_parallel > 0.0:
for lang_pair in self.lang_pairs:
forward_backward(
model.models[lang_pair],
sample[lang_pair],
lang_pair,
self.lambda_parallel,
)
if self.lambda_otf_bt > 0.0:
for lang_pair in self.lang_pairs:
sample_key = _get_bt_dataset_key(lang_pair)
forward_backward(
model.models[lang_pair],
sample[sample_key],
sample_key,
self.lambda_otf_bt,
)
if self.lambda_denoising > 0.0:
for lang_pair in self.lang_pairs:
_, tgt = lang_pair.split("-")
sample_key = _get_denoising_dataset_key(lang_pair)
forward_backward(
model.models["{0}-{0}".format(tgt)],
sample[sample_key],
sample_key,
self.lambda_denoising,
)
return agg_loss, agg_sample_size, agg_logging_output
def update_step(self, num_updates):
def lambda_step_func(config, n_iter):
"""
Update a lambda value according to its schedule configuration.
"""
ranges = [
i
for i in range(len(config) - 1)
if config[i][0] <= n_iter < config[i + 1][0]
]
if len(ranges) == 0:
assert n_iter >= config[-1][0]
return config[-1][1]
assert len(ranges) == 1
i = ranges[0]
x_a, y_a = config[i]
x_b, y_b = config[i + 1]
return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
if self.lambda_parallel_steps is not None:
self.lambda_parallel = lambda_step_func(
self.lambda_parallel_steps, num_updates
)
if self.lambda_denoising_steps is not None:
self.lambda_denoising = lambda_step_func(
self.lambda_denoising_steps, num_updates
)
if self.lambda_otf_bt_steps is not None:
self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)

View File

@@ -0,0 +1,286 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import contextlib
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING, II, open_dict, OmegaConf
import numpy as np
from fairseq.data import (
ConcatSentencesDataset,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
OffsetTokensDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
RollDataset,
SortDataset,
StripTokenDataset,
data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import FairseqDataclass, FairseqTask, register_task
from fairseq.dataclass import ChoiceEnum
logger = logging.getLogger(__name__)
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class SentencePredictionConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
num_classes: int = field(
default=-1,
metadata={"help": "number of classes or regression targets"},
)
init_token: Optional[int] = field(
default=None,
metadata={"help": "add token at the beginning of each batch item"},
)
separator_token: Optional[int] = field(
default=None,
metadata={"help": "add separator token between inputs"},
)
no_shuffle: bool = field(
default=False,
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed tokens_per_sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
add_prev_output_tokens: bool = field(
default=False,
metadata={
"help": "add prev_output_tokens to sample, used for encoder-decoder arch"
},
)
max_positions: int = field(
default=512,
metadata={"help": "max tokens per example"},
)
regression_target: bool = II("criterion.regression_target")
classification_head_name: str = II("criterion.classification_head_name")
seed: int = II("common.seed")
@register_task("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionTask(FairseqTask):
"""
Sentence (or sentence pair) prediction (classification or regression) task.
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
def __init__(self, cfg, data_dictionary, label_dictionary):
super().__init__(cfg)
self.dictionary = data_dictionary
self._label_dictionary = label_dictionary
@classmethod
def load_dictionary(cls, filename):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol("<mask>")
return dictionary
@classmethod
def setup_task(cls, cfg, **kwargs):
assert cfg.num_classes > 0, "Must set task.num_classes"
# load data dictionary
data_dict = cls.load_dictionary(
os.path.join(cfg.data, "input0", "dict.txt"),
)
logger.info("[input] dictionary: {} types".format(len(data_dict)))
# load label dictionary
if not cfg.regression_target:
label_dict = cls.load_dictionary(
os.path.join(cfg.data, "label", "dict.txt"),
)
logger.info("[label] dictionary: {} types".format(len(label_dict)))
else:
label_dict = data_dict
return cls(cfg, data_dict, label_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
def get_path(key, split):
return os.path.join(self.cfg.data, key, split)
def make_dataset(key, dictionary):
split_path = get_path(key, split)
try:
dataset = data_utils.load_indexed_dataset(
split_path,
dictionary,
combine=combine,
)
except Exception as e:
if "StorageException: [404] Path not found" in str(e):
logger.warning(f"dataset {e} not found")
dataset = None
else:
raise e
return dataset
input0 = make_dataset("input0", self.source_dictionary)
assert input0 is not None, "could not find dataset: {}".format(
get_path("input0", split)
)
input1 = make_dataset("input1", self.source_dictionary)
if self.cfg.init_token is not None:
input0 = PrependTokenDataset(input0, self.cfg.init_token)
if input1 is None:
src_tokens = input0
else:
if self.cfg.separator_token is not None:
input1 = PrependTokenDataset(input1, self.cfg.separator_token)
src_tokens = ConcatSentencesDataset(input0, input1)
with data_utils.numpy_seed(self.cfg.seed):
shuffle = np.random.permutation(len(src_tokens))
src_tokens = maybe_shorten_dataset(
src_tokens,
split,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.max_positions(),
self.cfg.seed,
)
dataset = {
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": NumelDataset(src_tokens, reduce=False),
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
}
if self.cfg.add_prev_output_tokens:
prev_tokens_dataset = RightPadDataset(
RollDataset(src_tokens, 1),
pad_idx=self.dictionary.pad(),
)
dataset["net_input"].update(
prev_output_tokens=prev_tokens_dataset,
)
if not self.cfg.regression_target:
label_dataset = make_dataset("label", self.label_dictionary)
if label_dataset is not None:
dataset.update(
target=OffsetTokensDataset(
StripTokenDataset(
label_dataset,
id_to_strip=self.label_dictionary.eos(),
),
offset=-self.label_dictionary.nspecial,
)
)
else:
label_path = "{0}.label".format(get_path("label", split))
if os.path.exists(label_path):
def parse_regression_target(i, line):
values = line.split()
assert (
len(values) == self.cfg.num_classes
), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"'
return [float(x) for x in values]
with open(label_path) as h:
dataset.update(
target=RawLabelDataset(
[
parse_regression_target(i, line.strip())
for i, line in enumerate(h.readlines())
]
)
)
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[src_tokens.sizes],
)
if self.cfg.no_shuffle:
dataset = nested_dataset
else:
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, cfg, from_checkpoint=False):
from fairseq import models
with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
cfg.max_positions = self.cfg.max_positions
model = models.build_model(cfg, self, from_checkpoint)
model.register_classification_head(
self.cfg.classification_head_name,
num_classes=self.cfg.num_classes,
)
return model
def max_positions(self):
return self.cfg.max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
@property
def label_dictionary(self):
return self._label_dictionary

View File

@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import contextlib
from omegaconf import open_dict, OmegaConf
from fairseq.tasks import register_task
from fairseq.tasks.sentence_prediction import (
SentencePredictionTask,
SentencePredictionConfig,
)
logger = logging.getLogger(__name__)
@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionAdapterTask(SentencePredictionTask):
def build_model(self, cfg):
from fairseq import models
with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
cfg.max_positions = self.cfg.max_positions
model = models.build_model(cfg, self)
model.register_classification_head(
self.cfg.classification_head_name,
num_classes=self.cfg.num_classes,
)
logger.info("Freezing Embedding Parameters")
for parameter in model.encoder.sentence_encoder.embed_positions.parameters():
parameter.requires_grad = False
for (
parameter
) in model.encoder.sentence_encoder.layernorm_embedding.parameters():
parameter.requires_grad = False
for parameter in model.encoder.sentence_encoder.embed_tokens.parameters():
parameter.requires_grad = False
logger.info("Freezing Adapters")
for k, v in model.encoder.sentence_encoder.layers._modules.items():
logger.info("Freezing Adapters in Layer " + str(k))
if hasattr(v, "adapter_layer_norm"):
logger.info("Freezing Adapter LN")
for parameter in v.adapter_layer_norm.parameters():
parameter.requires_grad = False
for parameter in v.adapter_modules.parameters():
parameter.requires_grad = False
return model

View File

@@ -0,0 +1,219 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import numpy as np
from fairseq import utils
from fairseq.data import (
ConcatSentencesDataset,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
TruncateDataset,
data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("sentence_ranking")
class SentenceRankingTask(LegacyFairseqTask):
"""
Ranking task on multiple sentences.
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("data", metavar="FILE", help="file prefix for data")
parser.add_argument(
"--num-classes", type=int, help="number of sentences to be ranked"
)
parser.add_argument(
"--init-token",
type=int,
help="add token at the beginning of each batch item",
)
parser.add_argument(
"--separator-token", type=int, help="add separator token between inputs"
)
parser.add_argument("--no-shuffle", action="store_true")
parser.add_argument(
"--shorten-method",
default="none",
choices=["none", "truncate", "random_crop"],
help="if not none, shorten sequences that exceed --tokens-per-sample",
)
parser.add_argument(
"--shorten-data-split-list",
default="",
help="comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)',
)
parser.add_argument(
"--max-option-length", type=int, help="max length for each option"
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def load_dictionary(cls, args, filename, source=True):
"""Load the dictionary from the filename
Args:
filename (str): the filename
"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol("<mask>")
return dictionary
@classmethod
def setup_task(cls, args, **kwargs):
assert (
args.criterion == "sentence_ranking"
), "Must set --criterion=sentence_ranking"
# load data dictionary
data_dict = cls.load_dictionary(
args,
os.path.join(args.data, "input0", "dict.txt"),
source=True,
)
logger.info("[input] dictionary: {} types".format(len(data_dict)))
return SentenceRankingTask(args, data_dict)
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
def get_path(type, split):
return os.path.join(self.args.data, type, split)
def make_dataset(type, dictionary):
split_path = get_path(type, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
combine=combine,
)
return dataset
input0 = make_dataset("input0", self.source_dictionary)
input_options = [
make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary)
for idx in range(self.args.num_classes)
]
if self.args.separator_token is not None:
input0 = PrependTokenDataset(input0, self.args.separator_token)
src_tokens = []
for input_option in input_options:
if self.args.init_token is not None:
input_option = PrependTokenDataset(input_option, self.args.init_token)
if self.args.max_option_length is not None:
input_option = TruncateDataset(
input_option, self.args.max_option_length
)
src_token = ConcatSentencesDataset(input_option, input0)
src_token = maybe_shorten_dataset(
src_token,
split,
self.args.shorten_data_split_list,
self.args.shorten_method,
self.args.max_positions,
self.args.seed,
)
src_tokens.append(src_token)
with data_utils.numpy_seed(self.args.seed):
shuffle = np.random.permutation(len(src_tokens[0]))
dataset = {
"id": IdDataset(),
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens[0], reduce=True),
}
for src_token_idx in range(len(src_tokens)):
dataset.update(
{
"net_input{idx}".format(idx=src_token_idx + 1): {
"src_tokens": RightPadDataset(
src_tokens[src_token_idx],
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": NumelDataset(
src_tokens[src_token_idx], reduce=False
),
}
}
)
label_path = "{}.label".format(get_path("label", split))
if os.path.exists(label_path):
with open(label_path) as h:
dataset.update(
target=RawLabelDataset([int(x.strip()) for x in h.readlines()])
)
nested_dataset = NestedDictionaryDataset(
dataset,
sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
)
if self.args.no_shuffle:
dataset = nested_dataset
else:
dataset = SortDataset(
nested_dataset,
# shuffle
sort_order=[shuffle],
)
logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args, from_checkpoint=False):
from fairseq import models
model = models.build_model(args, self, from_checkpoint)
model.register_classification_head(
getattr(args, "ranking_head_name", "sentence_classification_head"),
num_classes=1,
)
return model
def max_positions(self):
return self.args.max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import TranslationTask, TranslationConfig
try:
import examples.simultaneous_translation # noqa
import_successful = True
except BaseException:
import_successful = False
logger = logging.getLogger(__name__)
def check_import(flag):
if not flag:
raise ImportError(
"'examples.simultaneous_translation' is not correctly imported. "
"Please considering `pip install -e $FAIRSEQ_DIR`."
)
@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
def __init__(self, args, tgt_dict):
check_import(import_successful)
super().__init__(args, tgt_dict)
@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
def __init__(self, cfg, src_dict, tgt_dict):
check_import(import_successful)
super().__init__(cfg, src_dict, tgt_dict)

View File

@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from omegaconf import II, MISSING
from fairseq import utils
from fairseq.data import (
AppendTokenDataset,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumelDataset,
PadDataset,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from ..data.indexed_dataset import get_available_dataset_impl
logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class SpanMaskedLMConfig(FairseqDataclass):
shuffle: bool = field(
default=False,
)
noise_density: float = field(
default=0.15,
metadata={"help": "What fraction of the tokens to select as noise"},
)
mean_noise_span_length: float = field(
default=3,
metadata={"help": "Mean noise span length, must be >= 1"},
)
data: str = field(
default=MISSING,
metadata={
"help": "colon separated path to data directories list, "
"will be iterated upon during epochs in round-robin manner"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="none",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
seed: int = II("common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
max_source_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the target sequence"}
)
include_target_tokens: bool = field(
default=False,
metadata={
"help": "include target tokens in model input. this is used for data2vec"
},
)
@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig)
class SpanMaskedLMTask(FairseqTask):
"""
Span masked language modeling task. (ie. T5)
"""
cfg: SpanMaskedLMConfig
def __init__(self, cfg, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
@classmethod
def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs):
"""Setup the task."""
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(cfg, "shuffle"):
cfg.shuffle = False
return cls(cfg, dictionary)
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.dictionary,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
dataset = StripTokenDataset(dataset, self.dictionary.eos())
dataset = maybe_shorten_dataset(
dataset,
split,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.cfg.tokens_per_sample,
self.cfg.seed,
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.cfg.tokens_per_sample - 2, # one less for <s> and one for </s>
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
return dataset
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset = self._load_dataset_split(split, epoch, combine)
self.datasets[split] = SpanMaskedTokensDataset(
dataset,
self.dictionary,
noise_density=self.cfg.noise_density,
mean_noise_span_length=self.cfg.mean_noise_span_length,
shuffle=self.cfg.shuffle,
seed=self.cfg.seed,
)
logger.info(
"Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
split,
len(self.datasets[split]),
)
)
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
"""
Generate batches for inference. We assume that the input begins with a
bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
"""
pad = self.source_dictionary.pad()
eos = self.source_dictionary.eos()
src_dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=self.cfg.tokens_per_sample - 2, # for <s> and </s>
pad=pad,
eos=eos,
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
prev_output_tokens = PrependTokenDataset(
StripTokenDataset(src_dataset, eos), eos
)
src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
return NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": src_dataset,
"src_lengths": NumelDataset(src_dataset, reduce=False),
"prev_output_tokens": PadDataset(
prev_output_tokens, pad_idx=pad, left_pad=False
),
},
"target": src_dataset,
},
sizes=[np.array(src_lengths)],
)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.dictionary

View File

@@ -0,0 +1,597 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
logger = logging.getLogger(__name__)
class StackUnitSequenceGenerator(nn.Module):
def __init__(self, tgt_dict, vocab_size):
super().__init__()
self.pad = tgt_dict.pad()
self.eos = tgt_dict.eos()
self.unk = tgt_dict.unk()
self.offset = len(tgt_dict) - vocab_size
self.vocab_size = vocab_size
def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor:
if n_frames_per_step <= 1:
return input
bsz, _, n = input.shape
assert n == n_frames_per_step
scale = [
pow(self.vocab_size, n_frames_per_step - 1 - i)
for i in range(n_frames_per_step)
]
scale = torch.LongTensor(scale).squeeze(0).to(input.device)
mask = input >= self.offset
res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset
return res
@torch.no_grad()
def generate(self, models, sample, **kwargs):
# currently only support viterbi search for stacked units
model = models[0]
model.eval()
max_len = model.max_decoder_positions()
# TODO: incorporate max_len_a and max_len_b
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
bsz, src_len, _ = src_tokens.size()
n_frames_per_step = model.decoder.n_frames_per_step
# initialize
encoder_out = model.forward_encoder(
src_tokens, src_lengths, speaker=sample["speaker"]
)
incremental_state = {}
pred_out, attn, scores = [], [], []
finished = src_tokens.new_zeros((bsz,)).bool()
prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos)
for _ in range(max_len):
cur_out, cur_extra = model.forward_decoder(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
)
lprobs = model.get_normalized_probs([cur_out], log_probs=True)
# never select pad, unk
lprobs[:, :, self.pad] = -math.inf
lprobs[:, :, self.unk] = -math.inf
cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2)
scores.append(cur_pred_lprob)
pred_out.append(cur_pred_out)
prev_output_tokens = torch.cat(
(
prev_output_tokens,
self.pack_units(
cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step
),
),
dim=1,
)
attn.append(cur_extra["attn"][0])
cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1)
finished = finished | cur_finished
if finished.sum().item() == bsz:
break
pred_out = torch.cat(pred_out, dim=1).view(bsz, -1)
attn = torch.cat(attn, dim=2)
alignment = attn.max(dim=1)[1]
attn = attn.repeat_interleave(n_frames_per_step, dim=2)
alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
scores = torch.cat(scores, dim=1)
eos_idx = (pred_out == self.eos).nonzero(as_tuple=True)
out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len)
for b, l in zip(eos_idx[0], eos_idx[1]):
out_lens[b] = min(l, out_lens[b])
hypos = [
[
{
"tokens": pred_out[b, :out_len],
"attn": attn[b, :, :out_len],
"alignment": alignment[b, :out_len],
"positional_scores": scores[b, :out_len],
"score": utils.item(scores[b, :out_len].sum().data),
}
]
for b, out_len in zip(range(bsz), out_lens)
]
return hypos
@register_task("speech_to_speech")
class SpeechToSpeechTask(LegacyFairseqTask):
@classmethod
def add_args(cls, parser):
parser.add_argument("data", help="manifest root path")
parser.add_argument(
"--config-yaml",
type=str,
default="config.yaml",
help="Configuration YAML filename (under manifest root)",
)
parser.add_argument(
"--multitask-config-yaml",
type=str,
default=None,
help="Configuration YAML filename for the multitasks (under manifest root)",
)
parser.add_argument(
"--max-source-positions",
default=6000,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
parser.add_argument(
"--target-is-code",
action="store_true",
help="set if target is discrete unit instead of spectrogram",
)
parser.add_argument(
"--target-code-size", type=int, default=None, help="# discrete units"
)
parser.add_argument(
"--n-frames-per-step",
type=int,
default=1,
help="# stacked frames, use 0 for reduced discrete unit sequence",
)
parser.add_argument("--eval-inference", action="store_true")
parser.add_argument(
"--eval-args",
type=str,
default="{}",
help='generation args for speech-to-unit model , e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string',
)
parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
parser.add_argument(
"--mcd-normalize-type",
type=str,
default="targ",
choices=["targ", "pred", "path"],
)
parser.add_argument(
"--vocoder",
type=str,
default="griffin_lim",
choices=["griffin_lim", "hifigan", "code_hifigan"],
)
parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
parser.add_argument(
"--infer-target-lang",
type=str,
default="",
help="target language for inference",
)
def __init__(self, args, tgt_dict, infer_tgt_lang_id=None):
super().__init__(args)
self.tgt_dict = tgt_dict
self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
self.multitask_tasks = {}
self.tgt_dict_mt = None
self.eos_token_mt = None
if getattr(args, "multitask_config_yaml", None) is not None:
multitask_cfg = MultitaskConfig(
Path(args.data) / args.multitask_config_yaml
)
first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
for i, (task_name, task_config) in enumerate(
multitask_cfg.get_all_tasks().items()
):
task_obj = DummyMultiTask(
task_config,
task_config.tgt_dict,
first_pass=i == first_pass_task_idx,
)
self.multitask_tasks[task_name] = task_obj
if task_obj.is_first_pass_decoder:
self.tgt_dict_mt = task_obj.target_dictionary
if task_config.prepend_bos_and_append_tgt_lang_tag:
self.eos_token_mt = task_config.eos_token
assert not isinstance(self.eos_token_mt, List)
if not self.eos_token_mt:
raise Warning(
"Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
)
self._infer_tgt_lang_id = infer_tgt_lang_id
@classmethod
def setup_task(cls, args, **kwargs):
data_cfg = data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
tgt_dict = None
infer_tgt_lang_id = None
if args.target_is_code:
if data_cfg.prepend_tgt_lang_tag_as_bos:
# dictionary with language tags
dict_path = Path(args.data) / data_cfg.vocab_filename
if not dict_path.is_file():
raise FileNotFoundError(
f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}"
)
tgt_dict = Dictionary.load(dict_path.as_posix())
# target langauge for inference
if args.infer_target_lang != "":
tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
args.infer_target_lang
)
infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
assert infer_tgt_lang_id != tgt_dict.unk()
else:
assert args.target_code_size is not None
tgt_dict = Dictionary()
for i in range(args.target_code_size):
tgt_dict.add_symbol(str(i))
logger.info(f"dictionary size: " f"{len(tgt_dict):,}")
if getattr(args, "train_subset", None) is not None:
if not all(s.startswith("train") for s in args.train_subset.split(",")):
raise ValueError('Train splits should be named like "train*".')
assert args.n_frames_per_step >= 1
assert (
not args.eval_inference
or (args.target_is_code and args.vocoder == "code_hifigan")
or (not args.target_is_code and args.vocoder != "code_hifigan")
)
return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)
def build_criterion(self, args):
from fairseq import criterions
if len(self.multitask_tasks) > 0:
if self.args.target_is_code and not args._name.startswith("speech_to_unit"):
raise ValueError(
"set --criterion speech_to_unit for speech-to-unit loss with multitask"
)
elif not self.args.target_is_code and not args._name.startswith(
"speech_to_spectrogram"
):
raise ValueError(
"set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
)
return criterions.build_criterion(args, self)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv(
root=self.args.data,
data_cfg=self.data_cfg,
splits=split,
is_train_split=split.startswith("train"),
epoch=epoch,
seed=self.args.seed,
target_is_code=self.args.target_is_code,
tgt_dict=self.target_dictionary,
n_frames_per_step=self.args.n_frames_per_step,
multitask=self.multitask_tasks,
)
@property
def target_dictionary(self):
return self.tgt_dict
@property
def target_dictionary_mt(self):
return self.tgt_dict_mt
@property
def source_dictionary(self):
return None
def max_positions(self):
return self.args.max_source_positions, self.args.max_target_positions
def build_model(self, args, from_checkpoint=False):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_transformed_channels
args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None
args.n_frames_per_step = self.args.n_frames_per_step
model = super().build_model(args, from_checkpoint)
if len(self.multitask_tasks) > 0:
from fairseq.models.speech_to_speech.s2s_transformer import (
S2STransformerMultitaskModelBase,
)
assert isinstance(model, S2STransformerMultitaskModelBase)
if self.args.eval_inference:
self.eval_gen_args = json.loads(self.args.eval_args)
self.generator = self.build_generator(
[model], Namespace(**self.eval_gen_args)
)
return model
def build_generator_dual_decoder(
self,
models,
args,
extra_gen_cls_kwargs=None,
):
from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
MultiDecoderSequenceGenerator,
)
return MultiDecoderSequenceGenerator(
models,
self.target_dictionary,
self.target_dictionary_mt,
beam_size=max(1, getattr(args, "beam", 1)),
beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
max_len_a_mt=getattr(args, "max_len_a_mt", 0),
max_len_b_mt=getattr(args, "max_len_b_mt", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
**extra_gen_cls_kwargs,
)
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
if not self.args.target_is_code or self.args.eval_inference:
from fairseq.models.text_to_speech.vocoder import get_vocoder
self.vocoder = get_vocoder(self.args, self.data_cfg)
self.vocoder = (
self.vocoder.cuda()
if torch.cuda.is_available() and not self.args.cpu
else self.vocoder.cpu()
)
has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None
if self.args.target_is_code:
if self.args.n_frames_per_step == 1:
if has_dual_decoder:
seq_generator = self.build_generator_dual_decoder(
models,
args,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
seq_generator = super().build_generator(
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
assert (
getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
), "only support viterbi search for stacked units"
seq_generator = StackUnitSequenceGenerator(
self.tgt_dict,
self.args.target_code_size,
)
else:
if has_dual_decoder:
if getattr(args, "teacher_forcing", False):
raise NotImplementedError
else:
from fairseq.speech_generator import MultiDecoderSpeechGenerator
generator = MultiDecoderSpeechGenerator
lang_token_ids_aux = {
i
for s, i in self.tgt_dict_mt.indices.items()
if TextTargetMultitaskData.is_lang_tag(s)
}
if extra_gen_cls_kwargs is None:
extra_gen_cls_kwargs = {}
extra_gen_cls_kwargs[
"symbols_to_strip_from_output"
] = lang_token_ids_aux
eos_id_mt = (
self.tgt_dict_mt.index(self.eos_token_mt)
if self.eos_token_mt
else None
)
assert eos_id_mt != self.tgt_dict_mt.unk()
extra_gen_cls_kwargs["eos_mt"] = eos_id_mt
seq_generator = generator(
models,
args,
self.vocoder,
self.data_cfg,
self.target_dictionary_mt,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
**extra_gen_cls_kwargs,
)
else:
if getattr(args, "teacher_forcing", False):
from fairseq.speech_generator import (
TeacherForcingAutoRegressiveSpeechGenerator,
)
generator = TeacherForcingAutoRegressiveSpeechGenerator
logger.info("Teacher forcing mode for generation")
else:
from fairseq.speech_generator import AutoRegressiveSpeechGenerator
generator = AutoRegressiveSpeechGenerator
seq_generator = generator(
models[0],
self.vocoder,
self.data_cfg,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
)
return seq_generator
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
for task_name, task_obj in self.multitask_tasks.items():
criterion.set_multitask_loss_weight(
task_name, task_obj.args.get_loss_weight(update_num)
)
if task_name in model.multitask_decoders:
model.multitask_decoders[task_name].train()
loss, sample_size, logging_output = super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
for task_name in self.multitask_tasks.keys():
if task_name in model.multitask_decoders:
model.multitask_decoders[task_name].eval()
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.args.eval_inference:
hypos, inference_losses = self.valid_step_with_inference(
sample, model, self.generator
)
for k, v in inference_losses.items():
assert k not in logging_output
logging_output[k] = v
return loss, sample_size, logging_output
def valid_step_with_inference(self, sample, model, generator):
if self.args.target_is_code:
hypos = generator.generate([model], sample)
tgt_lens = (
sample["target_lengths"] - 1
) * self.args.n_frames_per_step # strip <eos>
for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)):
hypos[b][0]["targ_waveform"] = self.vocoder(
{"code": f[:l] - 4}, # remove <bos>, <pad>, <eos>, <unk>
dur_prediction=self.eval_gen_args.get("dur_prediction", False),
)
if len(hypos[b][0]["tokens"]) > 0:
hypos[b][0]["waveform"] = self.vocoder(
{"code": hypos[b][0]["tokens"] - 4},
dur_prediction=self.eval_gen_args.get("dur_prediction", False),
)
else:
hypos[b][0]["waveform"] = torch.flip(
hypos[b][0]["targ_waveform"], dims=[0]
)
else:
hypos = [
[hypo] for hypo in generator.generate(model, sample, has_targ=True)
]
losses = {
"mcd_loss": 0.0,
"targ_frames": 0.0,
"pred_frames": 0.0,
"path_frames": 0.0,
"nins": 0.0,
"ndel": 0.0,
}
rets = batch_mel_cepstral_distortion(
[hypo[0]["targ_waveform"] for hypo in hypos],
[hypo[0]["waveform"] for hypo in hypos],
self.data_cfg.output_sample_rate,
normalize_type=None,
)
for d, extra in rets:
pathmap = extra[-1]
losses["mcd_loss"] += d.item()
losses["targ_frames"] += pathmap.size(0)
losses["pred_frames"] += pathmap.size(1)
losses["path_frames"] += pathmap.sum().item()
losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
losses["norm_frames"] = losses[
f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames"
]
return hypos, losses
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
if self._infer_tgt_lang_id is not None:
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
bos_token=self._infer_tgt_lang_id,
)
else:
return super().inference_step(
generator,
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)

View File

@@ -0,0 +1,350 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from pathlib import Path
from typing import List
from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import MultitaskConfig
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
@classmethod
def add_args(cls, parser):
parser.add_argument("data", help="manifest root path")
parser.add_argument(
"--config-yaml",
type=str,
default="config.yaml",
help="Configuration YAML filename (under manifest root)",
)
parser.add_argument(
"--multitask-config-yaml",
type=str,
default=None,
help="Configuration YAML filename for the multitasks (under manifest root)",
)
parser.add_argument(
"--max-source-positions",
default=6000,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
def __init__(self, args, tgt_dict):
super().__init__(args)
self.tgt_dict = tgt_dict
self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
self.speaker_to_id = self._get_speaker_to_id()
if (
self.data_cfg.prepend_tgt_lang_tag
and self.data_cfg.prepend_bos_and_append_tgt_lang_tag
):
raise ValueError(
"Please set only one of the two options to avoid adding target token multiple times"
)
self.multitask_tasks = {}
self.tgt_dict_mt = None
self.eos_token_mt = None
if getattr(args, "multitask_config_yaml", None) is not None:
multitask_cfg = MultitaskConfig(
Path(args.data) / args.multitask_config_yaml
)
first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
for i, (task_name, task_config) in enumerate(
multitask_cfg.get_all_tasks().items()
):
task_obj = DummyMultiTask(
task_config,
task_config.tgt_dict,
first_pass=i == first_pass_task_idx,
)
self.multitask_tasks[task_name] = task_obj
if task_obj.is_first_pass_decoder:
self.tgt_dict_mt = task_obj.target_dictionary
if task_config.prepend_bos_and_append_tgt_lang_tag:
self.eos_token_mt = task_config.eos_token
assert not isinstance(self.eos_token_mt, List)
if not self.eos_token_mt:
raise Warning(
"Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
)
def _get_speaker_to_id(self):
speaker_to_id = None
speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
if speaker_set_filename is not None:
speaker_set_path = Path(self.args.data) / speaker_set_filename
with open(speaker_set_path) as f:
speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
return speaker_to_id
@classmethod
def setup_task(cls, args, **kwargs):
data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
dict_path = Path(args.data) / data_cfg.vocab_filename
if not dict_path.is_file():
raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
tgt_dict = Dictionary.load(dict_path.as_posix())
logger.info(
f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}"
)
if getattr(args, "train_subset", None) is not None:
if not all(s.startswith("train") for s in args.train_subset.split(",")):
raise ValueError('Train splits should be named like "train*".')
return cls(args, tgt_dict)
def build_criterion(self, args):
from fairseq import criterions
if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
raise ValueError(
'Please set "--ignore-prefix-size 1" since '
"target language ID token is prepended as BOS."
)
return criterions.build_criterion(args, self)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
root=self.args.data,
cfg=self.data_cfg,
splits=split,
tgt_dict=self.tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
is_train_split=is_train_split,
epoch=epoch,
seed=self.args.seed,
speaker_to_id=self.speaker_to_id,
multitask=self.multitask_tasks,
)
@property
def target_dictionary(self):
return self.tgt_dict
@property
def target_dictionary_mt(self):
return self.tgt_dict_mt
@property
def source_dictionary(self):
return None
def max_positions(self):
return self.args.max_source_positions, self.args.max_target_positions
def build_model(self, args, from_checkpoint=False):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_channels
args.speaker_to_id = self.speaker_to_id
return super(SpeechToTextTask, self).build_model(args, from_checkpoint)
def build_generator_dual_decoder(
self,
models,
args,
extra_gen_cls_kwargs,
):
from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
MultiDecoderSequenceGenerator,
)
lang_token_ids_aux = {
i
for s, i in self.tgt_dict_mt.indices.items()
if TextTargetMultitaskData.is_lang_tag(s)
}
extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)
eos_id_mt = (
self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
)
assert eos_id_mt != self.tgt_dict_mt.unk()
extra_gen_cls_kwargs["eos_mt"] = eos_id_mt
return MultiDecoderSequenceGenerator(
models,
self.target_dictionary,
self.target_dictionary_mt,
beam_size=max(1, getattr(args, "beam", 1)),
beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
max_len_a_mt=getattr(args, "max_len_a_mt", 0),
max_len_b_mt=getattr(args, "max_len_b_mt", 0),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
len_penalty_mt=getattr(args, "lenpen_mt", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
**extra_gen_cls_kwargs,
)
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
raise ValueError(
'Please set "--prefix-size 1" since '
"target language ID token is prepended as BOS."
)
lang_token_ids = {
i
for s, i in self.tgt_dict.indices.items()
if SpeechToTextDataset.is_lang_tag(s)
}
if extra_gen_cls_kwargs is None:
extra_gen_cls_kwargs = {}
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
eos_token = (
args.eos_token
if "eos_token" in args and args.eos_token is not None
else self.data_cfg.config.get("eos_token", None)
)
if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
raise Warning(
"Please provide --eos_token to replace eos in sequence generator"
)
eos_id = self.tgt_dict.index(eos_token) if eos_token else None
extra_gen_cls_kwargs["eos"] = eos_id
has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None
if has_dual_decoder:
return self.build_generator_dual_decoder(
models,
args,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
return super().build_generator(
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
for task_name, task_obj in self.multitask_tasks.items():
criterion.set_multitask_loss_weight(
task_name, task_obj.args.get_loss_weight(update_num)
)
if task_name in model.multitask_decoders:
model.multitask_decoders[task_name].train()
loss, sample_size, logging_output = super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
for task_name, task_obj in self.multitask_tasks.items():
if task_name in model.multitask_decoders:
model.multitask_decoders[task_name].eval()
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def build_tokenizer(self, args):
logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))
def build_bpe(self, args):
logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
def get_interactive_tokens_and_lengths(self, lines, encode_fn):
n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
return lines, n_frames
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
return SpeechToTextDataset(
"interactive", False, self.data_cfg, src_tokens, src_lengths
)
class DummyMultiTask(LegacyFairseqTask):
def __init__(self, args, tgt_dict, first_pass=False):
super().__init__(args)
self.tgt_dict = tgt_dict
self.first_pass = first_pass
@property
def target_dictionary(self):
return self.tgt_dict
@property
def is_first_pass_decoder(self):
return self.first_pass
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
if self.args.decoder_type == "ctc":
model = models[0] # only support single model
encoder_out = model(**sample)
if hasattr(model, "get_logits"):
emissions = model.get_logits(
encoder_out
) # no need to normalize emissions
else:
emissions = model.get_normalized_probs(encoder_out, log_probs=True)
return generator.decode(
emissions.transpose(0, 1).float().cpu().contiguous()
)
else:
raise NotImplementedError("only ctc decoder is supported at the moment")
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
):
if self.args.decoder_type == "ctc":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
return W2lViterbiDecoder(args, self.tgt_dict)
else:
raise NotImplementedError("only ctc decoder is supported at the moment")

View File

@@ -0,0 +1,224 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
import torch
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from fairseq.data import Dictionary
from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING, DictConfig
logger = logging.getLogger(__name__)
class UnitDictionary(Dictionary):
"""
A fixed-sized Dictionary that operates on integer-valued tokens
wth a trivial (identity) token <-> id mapping.
Special symbols (bos, eos, ...) have ids above n_units.
"""
def __init__(
self,
*, # begin keyword-only arguments
n_units,
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
clip=False,
):
self.n_units = n_units
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.clip = clip
self.symbols = []
self.count = []
self.indices = {}
for i in range(n_units):
self.add_symbol(str(i))
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor:
words = [int(x) for x in line.split()]
if self.clip:
words = [min(self.n_units - 1, word) for word in words]
if prepend_bos:
words = [self.bos_index] + words
if append_eos:
words.append(self.eos_index)
ids = torch.IntTensor(words)
return ids
@dataclass
class SpeechUnitModelingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "Path to data config.json"})
max_token_duration: int = field(
default=20, metadata={"help": "all token durations are capped to this value"}
)
tokens_per_sample: int = field(
default=1024, metadata={"help": "tokens in a sample"}
)
max_target_positions: int = field(
default=1024, metadata={"help": "max target positions"}
)
# duration modeling
ignore_duration_input: bool = field(
default=False, metadata={"help": "whether token durations should be zeroed out"}
)
discrete_duration: bool = field(
default=False, metadata={"help": "treat duration as discrete variable"}
)
# F0 modeling
ignore_f0_input: bool = field(
default=False, metadata={"help": "whether F0 should be zeroed out"}
)
discrete_f0: bool = field(
default=False, metadata={"help": "load quantized f0. get bin from config"}
)
log_f0: bool = field(
default=False, metadata={"help": "whether f0 should be modeled in log space"}
)
normalize_f0_mean: bool = field(
default=False, metadata={"help": "whether normalize f0 by speaker mean"}
)
normalize_f0_std: bool = field(
default=False, metadata={"help": "whether normalize f0 by speaker stddev"}
)
interpolate_f0: bool = field(
default=False,
metadata={"help": "whether interpolate f0 for non-voiced segments"},
)
# input/output streams
stream_shifts: str = field(
default="0,0",
metadata={
"help": (
"comma-separated integer list denoting right-shift for "
"duration and pitch streams"
)
},
)
@register_task("speech_unit_modeling", dataclass=SpeechUnitModelingConfig)
class SpeechUnitLanguageModelingTask(FairseqTask):
def __init__(self, cfg: SpeechUnitModelingConfig) -> None:
super().__init__(cfg)
assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean
self.data_config = ExpressiveCodeDataConfig(cfg.data)
self._source_dictionary = self._target_dictionary = UnitDictionary(
n_units=self.data_config.n_units
)
self._source_duration_dictionary = self._target_duration_dictionary = (
UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True)
if self.cfg.discrete_duration
else None
)
self._source_f0_dictionary = self._target_f0_dictionary = (
UnitDictionary(n_units=self.data_config.f0_vq_n_units)
if self.cfg.discrete_f0
else None
)
self._channel_names = ["token", "duration", "f0"]
self._channel_sizes = [
len(self.target_dictionary),
len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1,
len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1,
]
@property
def source_dictionary(self) -> Optional[Dictionary]:
return self._source_dictionary
@property
def source_duration_dictionary(self) -> Optional[Dictionary]:
return self._source_duration_dictionary
@property
def source_f0_dictionary(self) -> Optional[Dictionary]:
return self._source_f0_dictionary
@property
def channel_names(self) -> List[str]:
return self._channel_names
@property
def channel_sizes(self) -> List[int]:
return self._channel_sizes
@property
def dictionary(self) -> Optional[Dictionary]:
return self._source_dictionary
@property
def target_dictionary(self) -> Optional[Dictionary]:
return self._target_dictionary
@property
def target_duration_dictionary(self) -> Optional[Dictionary]:
return self._target_duration_dictionary
@property
def target_f0_dictionary(self) -> Optional[Dictionary]:
return self._target_f0_dictionary
@property
def dictionaries(self) -> List[Dictionary]:
return [self._dictionaries[l] for l in self.cfg.labels]
@classmethod
def setup_task(
cls, cfg: SpeechUnitModelingConfig, **kwargs
) -> "SpeechUnitLanguageModelingTask":
return cls(cfg)
def load_dataset(self, split: str, **kwargs) -> None:
self.datasets[split] = CodeDataset(
manifest=self.data_config.manifests[split],
dictionary=self.source_dictionary,
dur_dictionary=self.source_duration_dictionary,
f0_dictionary=self.source_f0_dictionary,
config=self.data_config,
discrete_dur=self.cfg.discrete_duration,
discrete_f0=self.cfg.discrete_f0,
log_f0=self.cfg.log_f0,
normalize_f0_mean=self.cfg.normalize_f0_mean,
normalize_f0_std=self.cfg.normalize_f0_std,
interpolate_f0=self.cfg.interpolate_f0,
shifts=self.cfg.stream_shifts,
)
def max_positions(self) -> Tuple[int, int]:
return (sys.maxsize, sys.maxsize)
def build_criterion(self, cfg: DictConfig):
import fairseq.criterions
return fairseq.criterions.build_criterion(cfg, self)

View File

@@ -0,0 +1,501 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import os.path as op
import torch
import torch.nn.functional as F
import numpy as np
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
AutoRegressiveSpeechGenerator,
NonAutoregressiveSpeechGenerator,
TeacherForcingAutoRegressiveSpeechGenerator,
)
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
try:
from tensorboardX import SummaryWriter
except ImportError:
logger.info("Please install tensorboardX: pip install tensorboardX")
SummaryWriter = None
@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
@staticmethod
def add_args(parser):
parser.add_argument("data", help="manifest root path")
parser.add_argument(
"--config-yaml",
type=str,
default="config.yaml",
help="Configuration YAML filename (under manifest root)",
)
parser.add_argument(
"--max-source-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1200,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
parser.add_argument("--n-frames-per-step", type=int, default=1)
parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
parser.add_argument("--eval-inference", action="store_true")
parser.add_argument("--eval-tb-nsample", type=int, default=8)
parser.add_argument("--vocoder", type=str, default="griffin_lim")
parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
def __init__(self, args, src_dict):
super().__init__(args, src_dict)
self.src_dict = src_dict
self.sr = self.data_cfg.config.get("features").get("sample_rate")
self.tensorboard_writer = None
self.tensorboard_dir = ""
if args.tensorboard_logdir and SummaryWriter is not None:
self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
is_train_split = split.startswith("train")
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
self.args.data,
self.data_cfg,
split,
self.src_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split=is_train_split,
epoch=epoch,
seed=self.args.seed,
n_frames_per_step=self.args.n_frames_per_step,
speaker_to_id=self.speaker_to_id,
)
@property
def target_dictionary(self):
return None
@property
def source_dictionary(self):
return self.src_dict
def get_speaker_embeddings_path(self):
speaker_emb_path = None
if self.data_cfg.config.get("speaker_emb_filename") is not None:
speaker_emb_path = op.join(
self.args.data, self.data_cfg.config.get("speaker_emb_filename")
)
return speaker_emb_path
@classmethod
def get_speaker_embeddings(cls, args):
embed_speaker = None
if args.speaker_to_id is not None:
if args.speaker_emb_path is None:
embed_speaker = torch.nn.Embedding(
len(args.speaker_to_id), args.speaker_embed_dim
)
else:
speaker_emb_mat = np.load(args.speaker_emb_path)
assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
embed_speaker = torch.nn.Embedding.from_pretrained(
torch.from_numpy(speaker_emb_mat),
freeze=True,
)
logger.info(
f"load speaker embeddings from {args.speaker_emb_path}. "
f"train embedding? {embed_speaker.weight.requires_grad}\n"
f"embeddings:\n{speaker_emb_mat}"
)
return embed_speaker
def build_model(self, cfg, from_checkpoint=False):
cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
cfg.speaker_emb_path = self.get_speaker_embeddings_path()
model = super().build_model(cfg, from_checkpoint)
self.generator = None
if getattr(cfg, "eval_inference", False):
self.generator = self.build_generator([model], cfg)
return model
def build_generator(self, models, cfg, vocoder=None, **unused):
if vocoder is None:
vocoder = self.build_default_vocoder()
model = models[0]
if getattr(model, "NON_AUTOREGRESSIVE", False):
return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
else:
generator = AutoRegressiveSpeechGenerator
if getattr(cfg, "teacher_forcing", False):
generator = TeacherForcingAutoRegressiveSpeechGenerator
logger.info("Teacher forcing mode for generation")
return generator(
model,
vocoder,
self.data_cfg,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
)
def build_default_vocoder(self):
from fairseq.models.text_to_speech.vocoder import get_vocoder
vocoder = get_vocoder(self.args, self.data_cfg)
if torch.cuda.is_available() and not self.args.cpu:
vocoder = vocoder.cuda()
else:
vocoder = vocoder.cpu()
return vocoder
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if getattr(self.args, "eval_inference", False):
hypos, inference_losses = self.valid_step_with_inference(
sample, model, self.generator
)
for k, v in inference_losses.items():
assert k not in logging_output
logging_output[k] = v
picked_id = 0
if self.tensorboard_dir and (sample["id"] == picked_id).any():
self.log_tensorboard(
sample,
hypos[: self.args.eval_tb_nsample],
model._num_updates,
is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
)
return loss, sample_size, logging_output
def valid_step_with_inference(self, sample, model, generator):
hypos = generator.generate(model, sample, has_targ=True)
losses = {
"mcd_loss": 0.0,
"targ_frames": 0.0,
"pred_frames": 0.0,
"nins": 0.0,
"ndel": 0.0,
}
rets = batch_mel_cepstral_distortion(
[hypo["targ_waveform"] for hypo in hypos],
[hypo["waveform"] for hypo in hypos],
self.sr,
normalize_type=None,
)
for d, extra in rets:
pathmap = extra[-1]
losses["mcd_loss"] += d.item()
losses["targ_frames"] += pathmap.size(0)
losses["pred_frames"] += pathmap.size(1)
losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
return hypos, losses
def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
if self.tensorboard_writer is None:
self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
tb_writer = self.tensorboard_writer
for b in range(len(hypos)):
idx = sample["id"][b]
text = sample["src_texts"][b]
targ = hypos[b]["targ_feature"]
pred = hypos[b]["feature"]
attn = hypos[b]["attn"]
if is_na_model:
data = plot_tts_output(
[targ.transpose(0, 1), pred.transpose(0, 1)],
[f"target (idx={idx})", "output"],
attn,
"alignment",
ret_np=True,
suptitle=text,
)
else:
eos_prob = hypos[b]["eos_prob"]
data = plot_tts_output(
[targ.transpose(0, 1), pred.transpose(0, 1), attn],
[f"target (idx={idx})", "output", "alignment"],
eos_prob,
"eos prob",
ret_np=True,
suptitle=text,
)
tb_writer.add_image(
f"inference_sample_{b}", data, num_updates, dataformats="HWC"
)
if hypos[b]["waveform"] is not None:
targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
pred_wave = hypos[b]["waveform"].detach().cpu().float()
tb_writer.add_audio(
f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
)
tb_writer.add_audio(
f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
)
def save_figure_to_numpy(fig):
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return data
DEFAULT_V_MIN = np.log(1e-5)
def plot_tts_output(
data_2d,
title_2d,
data_1d,
title_1d,
figsize=(24, 4),
v_min=DEFAULT_V_MIN,
v_max=3,
ret_np=False,
suptitle="",
):
try:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
except ImportError:
raise ImportError("Please install Matplotlib: pip install matplotlib")
data_2d = [
x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
for x in data_2d
]
fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
if suptitle:
fig.suptitle(suptitle[:400]) # capped at 400 chars
axes = [axes] if len(data_2d) == 0 else axes
for ax, x, name in zip(axes, data_2d, title_2d):
ax.set_title(name)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
im = ax.imshow(
x,
origin="lower",
aspect="auto",
vmin=max(x.min(), v_min),
vmax=min(x.max(), v_max),
)
fig.colorbar(im, cax=cax, orientation="vertical")
if isinstance(data_1d, torch.Tensor):
data_1d = data_1d.detach().cpu().numpy()
axes[-1].plot(data_1d)
axes[-1].set_title(title_1d)
plt.tight_layout()
if ret_np:
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close(fig)
return data
def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
"""
for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs
offset=2 (1, 1),
offset=3 (2, 1), (1, 2)
offset=4 (2, 2), (1, 3)
offset=5 (2, 3)
constraints:
i + j = offset
min_j <= j < max_j
min_i <= offset - j < max_i
"""
if max_i is None:
max_i = offset + 1
if max_j is None:
max_j = offset + 1
min_j = max(min_j, offset - max_i + 1, 0)
max_j = min(max_j, offset - min_i + 1, offset + 1)
j = torch.arange(min_j, max_j)
i = offset - j
return torch.stack([i, j])
def batch_dynamic_time_warping(distance, shapes=None):
"""full batched DTW without any constraints
distance: (batchsize, max_M, max_N) matrix
shapes: (batchsize,) vector specifying (M, N) for each entry
"""
# ptr: 0=left, 1=up-left, 2=up
ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}
bsz, m, n = distance.size()
cumdist = torch.zeros_like(distance)
backptr = torch.zeros_like(distance).type(torch.int32) - 1
# initialize
cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
backptr[:, 0, :] = 0
backptr[:, :, 0] = 2
# DP with optimized anti-diagonal parallelization, O(M+N) steps
for offset in range(2, m + n - 1):
ind = antidiag_indices(offset, 1, m, 1, n)
c = torch.stack(
[
cumdist[:, ind[0], ind[1] - 1],
cumdist[:, ind[0] - 1, ind[1] - 1],
cumdist[:, ind[0] - 1, ind[1]],
],
dim=2,
)
v, b = c.min(axis=-1)
backptr[:, ind[0], ind[1]] = b.int()
cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]
# backtrace
pathmap = torch.zeros_like(backptr)
for b in range(bsz):
i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
dtwpath = [(i, j)]
while (i != 0 or j != 0) and len(dtwpath) < 10000:
assert i >= 0 and j >= 0
di, dj = ptr2dij[backptr[b, i, j].item()]
i, j = i + di, j + dj
dtwpath.append((i, j))
dtwpath = dtwpath[::-1]
indices = torch.from_numpy(np.array(dtwpath))
pathmap[b, indices[:, 0], indices[:, 1]] = 1
return cumdist, backptr, pathmap
def compute_l2_dist(x1, x2):
"""compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices"""
return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)
def compute_rms_dist(x1, x2):
l2_dist = compute_l2_dist(x1, x2)
return (l2_dist / x1.size(1)).pow(0.5)
def get_divisor(pathmap, normalize_type):
if normalize_type is None:
return 1
elif normalize_type == "len1":
return pathmap.size(0)
elif normalize_type == "len2":
return pathmap.size(1)
elif normalize_type == "path":
return pathmap.sum().item()
else:
raise ValueError(f"normalize_type {normalize_type} not supported")
def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
d, s, x1, x2 = [], [], [], []
for cur_y1, cur_y2 in zip(y1, y2):
assert cur_y1.ndim == 1 and cur_y2.ndim == 1
cur_x1 = feat_fn(cur_y1)
cur_x2 = feat_fn(cur_y2)
x1.append(cur_x1)
x2.append(cur_x2)
cur_d = dist_fn(cur_x1, cur_x2)
d.append(cur_d)
s.append(d[-1].size())
max_m = max(ss[0] for ss in s)
max_n = max(ss[1] for ss in s)
d = torch.stack(
[F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
)
s = torch.LongTensor(s).to(d.device)
cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)
rets = []
itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
cumdist = cumdist[:m, :n]
backptr = backptr[:m, :n]
pathmap = pathmap[:m, :n]
divisor = get_divisor(pathmap, normalize_type)
distortion = cumdist[-1, -1] / divisor
ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
rets.append(ret)
return rets
def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
"""
https://arxiv.org/pdf/2011.03568.pdf
The root mean squared error computed on 13-dimensional MFCC using DTW for
alignment. MFCC features are computed from an 80-channel log-mel
spectrogram using a 50ms Hann window and hop of 12.5ms.
y1: list of waveforms
y2: list of waveforms
sr: sampling rate
"""
try:
import torchaudio
except ImportError:
raise ImportError("Please install torchaudio: pip install torchaudio")
if mfcc_fn is None or mfcc_fn.sample_rate != sr:
melkwargs = {
"n_fft": int(0.05 * sr),
"win_length": int(0.05 * sr),
"hop_length": int(0.0125 * sr),
"f_min": 20,
"n_mels": 80,
"window_fn": torch.hann_window,
}
mfcc_fn = torchaudio.transforms.MFCC(
sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
).to(y1[0].device)
return batch_compute_distortion(
y1,
y2,
sr,
lambda y: mfcc_fn(y).transpose(-1, -2),
compute_rms_dist,
normalize_type,
)

View File

@@ -0,0 +1,497 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import itertools
import json
import logging
import os
from typing import Optional
from argparse import Namespace
from omegaconf import II
import numpy as np
from fairseq import metrics, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
data_utils,
encoders,
indexed_dataset,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
EVAL_BLEU_ORDER = 4
logger = logging.getLogger(__name__)
def load_langpair_dataset(
data_path,
split,
src,
src_dict,
tgt,
tgt_dict,
combine,
dataset_impl,
upsample_primary,
left_pad_source,
left_pad_target,
max_source_positions,
max_target_positions,
prepend_bos=False,
load_alignments=False,
truncate_source=False,
append_source_id=False,
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
if truncate_source:
src_dataset = AppendTokenDataset(
TruncateDataset(
StripTokenDataset(src_dataset, src_dict.eos()),
max_source_positions - 1,
),
src_dict.eos(),
)
src_datasets.append(src_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{} {} examples".format(
data_path, split_k, src, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
sample_ratios[0] = upsample_primary
src_dataset = ConcatDataset(src_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
if prepend_bos:
assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
eos = None
if append_source_id:
src_dataset = AppendTokenDataset(
src_dataset, src_dict.index("[{}]".format(src))
)
if tgt_dataset is not None:
tgt_dataset = AppendTokenDataset(
tgt_dataset, tgt_dict.index("[{}]".format(tgt))
)
eos = tgt_dict.index("[{}]".format(tgt))
align_dataset = None
if load_alignments:
align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
align_dataset = data_utils.load_indexed_dataset(
align_path, None, dataset_impl
)
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
return LanguagePairDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
align_dataset=align_dataset,
eos=eos,
num_buckets=num_buckets,
shuffle=shuffle,
pad_to_multiple=pad_to_multiple,
)
@dataclass
class TranslationConfig(FairseqDataclass):
data: Optional[str] = field(
default=None,
metadata={
"help": "colon separated path to data directories list, will be iterated upon during epochs "
"in round-robin manner; however, valid and test data are always in the first directory "
"to avoid the need for repeating them in all directories"
},
)
source_lang: Optional[str] = field(
default=None,
metadata={
"help": "source language",
"argparse_alias": "-s",
},
)
target_lang: Optional[str] = field(
default=None,
metadata={
"help": "target language",
"argparse_alias": "-t",
},
)
load_alignments: bool = field(
default=False, metadata={"help": "load the binarized alignments"}
)
left_pad_source: bool = field(
default=True, metadata={"help": "pad the source on the left"}
)
left_pad_target: bool = field(
default=False, metadata={"help": "pad the target on the left"}
)
max_source_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the target sequence"}
)
upsample_primary: int = field(
default=-1, metadata={"help": "the amount of upsample primary dataset"}
)
truncate_source: bool = field(
default=False, metadata={"help": "truncate source to max-source-positions"}
)
num_batch_buckets: int = field(
default=0,
metadata={
"help": "if >0, then bucket source and target lengths into "
"N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations"
},
)
train_subset: str = II("dataset.train_subset")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
required_seq_len_multiple: int = II("dataset.required_seq_len_multiple")
# options for reporting BLEU during validation
eval_bleu: bool = field(
default=False, metadata={"help": "evaluation with BLEU scores"}
)
eval_bleu_args: Optional[str] = field(
default="{}",
metadata={
"help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
},
)
eval_bleu_detok: str = field(
default="space",
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; "
"use 'space' to disable detokenization; see fairseq.data.encoders for other options"
},
)
eval_bleu_detok_args: Optional[str] = field(
default="{}",
metadata={"help": "args for building the tokenizer, if needed, as JSON string"},
)
eval_tokenized_bleu: bool = field(
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
)
eval_bleu_remove_bpe: Optional[str] = field(
default=None,
metadata={
"help": "remove BPE before computing BLEU",
"argparse_const": "@@ ",
},
)
eval_bleu_print_samples: bool = field(
default=False, metadata={"help": "print sample generations during validation"}
)
@register_task("translation", dataclass=TranslationConfig)
class TranslationTask(FairseqTask):
"""
Translate from one (source) language to another (target) language.
Args:
src_dict (~fairseq.data.Dictionary): dictionary for the source language
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
"""
cfg: TranslationConfig
def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict):
super().__init__(cfg)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, cfg: TranslationConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
# find language pair automatically
if cfg.source_lang is None or cfg.target_lang is None:
cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0])
if cfg.source_lang is None or cfg.target_lang is None:
raise Exception(
"Could not infer language pair, please provide it explicitly"
)
# load dictionaries
src_dict = cls.load_dictionary(
os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
)
tgt_dict = cls.load_dictionary(
os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
)
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict)))
logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))
return cls(cfg, src_dict, tgt_dict)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
if split != self.cfg.train_subset:
# if not training data set, use the first shard for valid and test
paths = paths[:1]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.cfg.source_lang, self.cfg.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
split,
src,
self.src_dict,
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=self.cfg.max_source_positions,
max_target_positions=self.cfg.max_target_positions,
load_alignments=self.cfg.load_alignments,
truncate_source=self.cfg.truncate_source,
num_buckets=self.cfg.num_batch_buckets,
shuffle=(split != "test"),
pad_to_multiple=self.cfg.required_seq_len_multiple,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
return LanguagePairDataset(
src_tokens,
src_lengths,
self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints,
)
def build_model(self, cfg, from_checkpoint=False):
model = super().build_model(cfg, from_checkpoint)
if self.cfg.eval_bleu:
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(self.cfg.eval_bleu_args)
self.sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
return model
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.cfg.eval_bleu:
bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output["_bleu_sys_len"] = bleu.sys_len
logging_output["_bleu_ref_len"] = bleu.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(bleu.counts) == EVAL_BLEU_ORDER
for i in range(EVAL_BLEU_ORDER):
logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.cfg.eval_bleu:
def sum_logs(key):
import torch
result = sum(log.get(key, 0) for log in logging_outputs)
if torch.is_tensor(result):
result = result.cpu()
return result
counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
counts.append(sum_logs("_bleu_counts_" + str(i)))
totals.append(sum_logs("_bleu_totals_" + str(i)))
if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
metrics.log_scalar("_bleu_counts", np.array(counts))
metrics.log_scalar("_bleu_totals", np.array(totals))
metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
def compute_bleu(meters):
import inspect
try:
from sacrebleu.metrics import BLEU
comp_bleu = BLEU.compute_bleu
except ImportError:
# compatibility API for sacrebleu 1.x
import sacrebleu
comp_bleu = sacrebleu.compute_bleu
fn_sig = inspect.getfullargspec(comp_bleu)[0]
if "smooth_method" in fn_sig:
smooth = {"smooth_method": "exp"}
else:
smooth = {"smooth": "exp"}
bleu = comp_bleu(
correct=meters["_bleu_counts"].sum,
total=meters["_bleu_totals"].sum,
sys_len=int(meters["_bleu_sys_len"].sum),
ref_len=int(meters["_bleu_ref_len"].sum),
**smooth,
)
return round(bleu.score, 2)
metrics.log_derived("bleu", compute_bleu)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.src_dict
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.tgt_dict
def _inference_with_bleu(self, generator, sample, model):
import sacrebleu
def decode(toks, escape_unk=False):
s = self.tgt_dict.string(
toks.int().cpu(),
self.cfg.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s
gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]["tokens"]))
refs.append(
decode(
utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
escape_unk=True, # don't count <unk> as matches to the hypo
)
)
if self.cfg.eval_bleu_print_samples:
logger.info("example hypothesis: " + hyps[0])
logger.info("example reference: " + refs[0])
if self.cfg.eval_tokenized_bleu:
return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
else:
return sacrebleu.corpus_bleu(hyps, [refs])

View File

@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from . import register_task
from .translation import TranslationTask, load_langpair_dataset
@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
"""
Translate from source language to target language with a model initialized with a multilingual pretrain.
Args:
src_dict (~fairseq.data.Dictionary): dictionary for the source language
tgt_dict (~fairseq.data.Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--langs', type=str, metavar='LANG',
help='comma-separated list of monolingual language, '
'for example, "en,de,fr". These should match the '
'langs from pretraining (and be in the same order). '
'You should always add all pretraining language idx '
'during finetuning.')
parser.add_argument('--prepend-bos', action='store_true',
help='prepend bos token to each sentence, which matches '
'mBART pretraining')
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
self.langs = args.langs.split(",")
for d in [src_dict, tgt_dict]:
for l in self.langs:
d.add_symbol("[{}]".format(l))
d.add_symbol("<mask>")
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
split,
src,
self.src_dict,
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=getattr(self.args, "max_source_positions", 1024),
max_target_positions=getattr(self.args, "max_target_positions", 1024),
load_alignments=self.args.load_alignments,
prepend_bos=getattr(self.args, "prepend_bos", False),
append_source_id=True,
)
def build_generator(self, models, args, **unused):
if getattr(args, "score_reference", False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(
self.target_dictionary,
eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
models,
self.target_dictionary,
beam_size=getattr(args, "beam", 5),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
source_tokens = []
for s_t in src_tokens:
s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
source_tokens.append(s_t)
dataset = LanguagePairDataset(
source_tokens,
src_lengths,
self.source_dictionary,
tgt_dict=self.target_dictionary,
constraints=constraints,
)
return dataset

View File

@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.tasks.translation import TranslationConfig, TranslationTask
from . import register_task
@dataclass
class TranslationFromPretrainedXLMConfig(TranslationConfig):
pass
@register_task(
"translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig
)
class TranslationFromPretrainedXLMTask(TranslationTask):
"""
Same as TranslationTask except use the MaskedLMDictionary class so that
we can load data that was binarized with the MaskedLMDictionary class.
This task should be used for the entire training pipeline when we want to
train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
training NMT with the pretrained XLM checkpoint, and subsequent evaluation
of that trained model.
"""
@classmethod
def load_dictionary(cls, filename):
"""Load the masked LM dictionary from the filename
Args:
filename (str): the filename
"""
return MaskedLMDictionary.load(filename)

View File

@@ -0,0 +1,195 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import (
TranslationConfig,
TranslationTask,
load_langpair_dataset,
)
from fairseq.utils import new_arange
NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"])
@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
noise: NOISE_CHOICES = field(
default="random_delete",
metadata={"help": "type of noise"},
)
@register_task("translation_lev", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinTask(TranslationTask):
"""
Translation (Sequence Generation) task for Levenshtein Transformer
See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
"""
cfg: TranslationLevenshteinConfig
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.cfg.source_lang, self.cfg.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
split,
src,
self.src_dict,
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=self.cfg.max_source_positions,
max_target_positions=self.cfg.max_target_positions,
prepend_bos=True,
)
def inject_noise(self, target_tokens):
def _random_delete(target_tokens):
pad = self.tgt_dict.pad()
bos = self.tgt_dict.bos()
eos = self.tgt_dict.eos()
max_len = target_tokens.size(1)
target_mask = target_tokens.eq(pad)
target_score = target_tokens.clone().float().uniform_()
target_score.masked_fill_(
target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
)
target_score.masked_fill_(target_mask, 1)
target_score, target_rank = target_score.sort(1)
target_length = target_mask.size(1) - target_mask.float().sum(
1, keepdim=True
)
# do not delete <bos> and <eos> (we assign 0 score for them)
target_cutoff = (
2
+ (
(target_length - 2)
* target_score.new_zeros(target_score.size(0), 1).uniform_()
).long()
)
target_cutoff = target_score.sort(1)[1] >= target_cutoff
prev_target_tokens = (
target_tokens.gather(1, target_rank)
.masked_fill_(target_cutoff, pad)
.gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
)
prev_target_tokens = prev_target_tokens[
:, : prev_target_tokens.ne(pad).sum(1).max()
]
return prev_target_tokens
def _random_mask(target_tokens):
pad = self.tgt_dict.pad()
bos = self.tgt_dict.bos()
eos = self.tgt_dict.eos()
unk = self.tgt_dict.unk()
target_masks = (
target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
)
target_score = target_tokens.clone().float().uniform_()
target_score.masked_fill_(~target_masks, 2.0)
target_length = target_masks.sum(1).float()
target_length = target_length * target_length.clone().uniform_()
target_length = target_length + 1 # make sure to mask at least one token.
_, target_rank = target_score.sort(1)
target_cutoff = new_arange(target_rank) < target_length[:, None].long()
prev_target_tokens = target_tokens.masked_fill(
target_cutoff.scatter(1, target_rank, target_cutoff), unk
)
return prev_target_tokens
def _full_mask(target_tokens):
pad = self.tgt_dict.pad()
bos = self.tgt_dict.bos()
eos = self.tgt_dict.eos()
unk = self.tgt_dict.unk()
target_mask = (
target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
)
return target_tokens.masked_fill(~target_mask, unk)
if self.cfg.noise == "random_delete":
return _random_delete(target_tokens)
elif self.cfg.noise == "random_mask":
return _random_mask(target_tokens)
elif self.cfg.noise == "full_mask":
return _full_mask(target_tokens)
elif self.cfg.noise == "no_noise":
return target_tokens
else:
raise NotImplementedError
def build_generator(self, models, args, **unused):
# add models input to match the API for SequenceGenerator
from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
return IterativeRefinementGenerator(
self.target_dictionary,
eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
max_iter=getattr(args, "iter_decode_max_iter", 10),
beam_size=getattr(args, "iter_decode_with_beam", 1),
reranking=getattr(args, "iter_decode_with_external_reranker", False),
decoding_format=getattr(args, "decoding_format", None),
adaptive=not getattr(args, "iter_decode_force_max_iter", False),
retain_history=getattr(args, "retain_iter_history", False),
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
# Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
raise NotImplementedError(
"Constrained decoding with the translation_lev task is not supported"
)
return LanguagePairDataset(
src_tokens, src_lengths, self.source_dictionary, append_bos=True
)
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
sample["prev_target"] = self.inject_noise(sample["target"])
loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
loss *= 0
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
sample["prev_target"] = self.inject_noise(sample["target"])
loss, sample_size, logging_output = criterion(model, sample)
return loss, sample_size, logging_output

View File

@@ -0,0 +1,441 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
import time
import torch
from fairseq.data import (
FairseqDataset,
LanguagePairDataset,
ListDataset,
data_utils,
iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction
###
def get_time_gap(s, e):
return (
datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
).__str__()
###
logger = logging.getLogger(__name__)
@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.
Args:
langs (List[str]): a list of languages that are being supported
dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
training (bool): whether the task should be configured for training or not
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='inference source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='inference target language')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
action=FileContentsAction)
parser.add_argument('--keep-inference-langtok', action='store_true',
help='keep language tokens in inference output (e.g. for analysis or debugging)')
SamplingMethod.add_arguments(parser)
MultilingualDatasetManager.add_args(parser)
# fmt: on
def __init__(self, args, langs, dicts, training):
super().__init__(args)
self.langs = langs
self.dicts = dicts
self.training = training
if training:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
# eval_lang_pairs for multilingual translation is usually all of the
# lang_pairs. However for other multitask settings or when we want to
# optimize for certain languages we want to use a different subset. Thus
# the eval_lang_pairs class variable is provided for classes that extend
# this class.
self.eval_lang_pairs = self.lang_pairs
# model_lang_pairs will be used to build encoder-decoder model pairs in
# models.build_model(). This allows multitask type of sub-class can
# build models other than the input lang_pairs
self.model_lang_pairs = self.lang_pairs
self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
self.check_dicts(self.dicts, self.source_langs, self.target_langs)
self.sampling_method = SamplingMethod.build_sampler(args, self)
self.data_manager = MultilingualDatasetManager.setup_data_manager(
args, self.lang_pairs, langs, dicts, self.sampling_method
)
def check_dicts(self, dicts, source_langs, target_langs):
if self.args.source_dict is not None or self.args.target_dict is not None:
# no need to check whether the source side and target side are sharing dictionaries
return
src_dict = dicts[source_langs[0]]
tgt_dict = dicts[target_langs[0]]
for src_lang in source_langs:
assert (
src_dict == dicts[src_lang]
), "Diffrent dictionary are specified for different source languages; "
"TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
for tgt_lang in target_langs:
assert (
tgt_dict == dicts[tgt_lang]
), "Diffrent dictionary are specified for different target languages; "
"TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
@classmethod
def setup_task(cls, args, **kwargs):
langs, dicts, training = MultilingualDatasetManager.prepare(
cls.load_dictionary, args, **kwargs
)
return cls(args, langs, dicts, training)
def has_sharded_data(self, split):
return self.data_manager.has_sharded_data(split)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if split in self.datasets:
dataset = self.datasets[split]
if self.has_sharded_data(split):
if self.args.virtual_epoch_size is not None:
if dataset.load_next_shard:
shard_epoch = dataset.shard_epoch
else:
# no need to load next shard so skip loading
# also this avoid always loading from beginning of the data
return
else:
shard_epoch = epoch
else:
# estimate the shard epoch from virtual data size and virtual epoch size
shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
if split in self.datasets:
del self.datasets[split]
logger.info("old dataset deleted manually")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
self.datasets[split] = self.data_manager.load_dataset(
split,
self.training,
epoch=epoch,
combine=combine,
shard_epoch=shard_epoch,
**kwargs,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the multilingual_translation task is not supported"
)
src_data = ListDataset(src_tokens, src_lengths)
dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
if self.args.lang_tok_replacing_bos_eos:
dataset = self.data_manager.alter_dataset_langtok(
dataset,
src_eos=self.source_dictionary.eos(),
src_lang=self.args.source_lang,
tgt_eos=self.target_dictionary.eos(),
tgt_lang=self.args.target_lang,
src_langtok_spec=src_langtok_spec,
tgt_langtok_spec=tgt_langtok_spec,
)
else:
dataset.src = self.data_manager.src_dataset_tranform_func(
self.args.source_lang,
self.args.target_lang,
dataset=dataset.src,
spec=src_langtok_spec,
)
return dataset
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
if not getattr(args, "keep_inference_langtok", False):
_, tgt_langtok_spec = self.args.langtoks["main"]
if tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}
return super().build_generator(
models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_model(self, args, from_checkpoint=False):
return super().build_model(args, from_checkpoint)
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
_, tgt_langtok_spec = self.args.langtoks["main"]
if not self.args.lang_tok_replacing_bos_eos:
if prefix_tokens is None and tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
prefix_tokens = (
torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
)
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)
else:
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
bos_token=self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
if tgt_langtok_spec
else self.target_dictionary.eos(),
)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def source_dictionary(self):
return self.data_manager.get_source_dictionary(self.source_langs[0])
@property
def target_dictionary(self):
return self.data_manager.get_target_dictionary(self.target_langs[0])
def create_batch_sampler_func(
self,
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=1,
seed=1,
):
def construct_batch_sampler(dataset, epoch):
splits = [
s for s, _ in self.datasets.items() if self.datasets[s] == dataset
]
split = splits[0] if len(splits) > 0 else None
# NEW implementation
if epoch is not None:
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
start_time = time.time()
logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
logger.info(
f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# filter examples that are too large
if max_positions is not None:
my_time = time.time()
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
logger.info(
f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# create mini-batches with given size constraints
my_time = time.time()
batch_sampler = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
logger.info(
f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(
f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
return batch_sampler
return construct_batch_sampler
# we need to override get_batch_iterator because we want to reset the epoch iterator each time
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 0).
data_buffer_size (int, optional): number of batches to
preload (default: 0).
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
(default: False).
grouped_shuffling (bool, optional): group batches with each groups
containing num_shards batches and shuffle groups. Reduces difference
between sequence lengths among workers for batches sorted by length.
update_epoch_batch_itr (bool optional): if true then donot use the cached
batch iterator for the epoch
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
# initialize the dataset with the correct starting epoch
assert isinstance(dataset, FairseqDataset)
if dataset in self.dataset_to_epoch_iter:
return self.dataset_to_epoch_iter[dataset]
if self.args.sampling_method == "RoundRobin":
batch_iter = super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
skip_remainder_batch=skip_remainder_batch,
update_epoch_batch_itr=update_epoch_batch_itr,
)
self.dataset_to_epoch_iter[dataset] = batch_iter
return batch_iter
construct_batch_sampler = self.create_batch_sampler_func(
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
)
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=construct_batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
)
return epoch_iter