Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git (synced 2026-04-30)
Add monkey-patched fairseq package to run on Python 3.11 (needed, at least, for our use of RVC)
136 modules/voice_conversion/fairseq/tasks/__init__.py Normal file
@@ -0,0 +1,136 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import argparse
import importlib
import os

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore

from .fairseq_task import FairseqTask, LegacyFairseqTask  # noqa


# register dataclass
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()


def setup_task(cfg: FairseqDataclass, **kwargs):
    task = None
    task_name = getattr(cfg, "task", None)

    if isinstance(task_name, str):
        # legacy tasks
        task = TASK_REGISTRY[task_name]
        if task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = dc.from_namespace(cfg)
    else:
        task_name = getattr(cfg, "_name", None)

        if task_name and task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = merge_with_parent(dc(), cfg)
            task = TASK_REGISTRY[task_name]

    assert (
        task is not None
    ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"

    return task.setup_task(cfg, **kwargs)


def register_task(name, dataclass=None):
    """
    New tasks can be added to fairseq with the
    :func:`~fairseq.tasks.register_task` function decorator.

    For example::

        @register_task('classification')
        class ClassificationTask(FairseqTask):
            (...)

    .. note::

        All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
        interface.

    Args:
        name (str): the name of the task
    """

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError("Cannot register duplicate task ({})".format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError(
                "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
            )
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                "Cannot register task with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)

        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        cls.__dataclass = dataclass
        if dataclass is not None:
            TASK_DATACLASS_REGISTRY[name] = dataclass

            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="task", node=node, provider="fairseq")

        return cls

    return register_task_cls


def get_task(name):
    return TASK_REGISTRY[name]


def import_tasks(tasks_dir, namespace):
    for file in os.listdir(tasks_dir):
        path = os.path.join(tasks_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            task_name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(namespace + "." + task_name)

            # expose `task_parser` for sphinx
            if task_name in TASK_REGISTRY:
                parser = argparse.ArgumentParser(add_help=False)
                group_task = parser.add_argument_group("Task name")
                # fmt: off
                group_task.add_argument('--task', metavar=task_name,
                                        help='Enable this task with: ``--task=' + task_name + '``')
                # fmt: on
                group_args = parser.add_argument_group(
                    "Additional command-line arguments"
                )
                TASK_REGISTRY[task_name].add_args(group_args)
                globals()[task_name + "_parser"] = parser


# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
import_tasks(tasks_dir, "fairseq.tasks")
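For orientation, a minimal sketch of how this registry is used by the task modules below (the task name "my_task" and its config class are hypothetical, not part of this commit): registering stores the class in TASK_REGISTRY and the dataclass in TASK_DATACLASS_REGISTRY plus Hydra's ConfigStore, so setup_task(cfg) can later resolve the task by name.

from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

@dataclass
class MyTaskConfig(FairseqDataclass):
    # hypothetical config; any FairseqDataclass subclass works here
    data: str = field(default="", metadata={"help": "path to data"})

@register_task("my_task", dataclass=MyTaskConfig)  # "my_task" is a made-up name
class MyTask(FairseqTask):
    @classmethod
    def setup_task(cls, cfg, **kwargs):
        return cls(cfg)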
343 modules/voice_conversion/fairseq/tasks/audio_finetuning.py Normal file
@@ -0,0 +1,343 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import torch
import json

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any

from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

from . import register_task
from .. import utils
from ..logging import metrics


logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )


def label_len_fn(label):
    return len(label.split(" "))


@dataclass
class AudioFinetuningConfig(AudioPretrainingConfig):
    # Options for reporting WER metrics during validation. Only applicable to
    # Seq2Seq models during fine-tuning
    eval_wer: bool = field(
        default=False, metadata={"help": "compute WER for Seq2Seq models"}
    )
    eval_wer_config: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(),
        metadata={"help": "beam search config for evaluating wer during training"},
    )
    eval_wer_tokenizer: Any = field(
        default=None,
        metadata={"help": "tokenizer config for evaluating wer during training"},
    )
    eval_wer_post_process: str = field(
        default="letter",
        metadata={
            "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
        },
    )
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_detok: Optional[str] = field(
        default=None,
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); "
            "required if using --eval-bleu; use 'space' to disable "
            "detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: str = field(
        default="{}", metadata={"help": "args for building the tokenizer, if needed"}
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None, metadata={"help": "remove BPE before computing BLEU"}
    )
    eval_bleu_args: str = field(
        default="{}",
        metadata={
            "help": "generation args for BLEU scoring, e.g., "
            '\'{"beam": 4, "lenpen": 0.6}\''
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    autoregressive: bool = field(
        default=False,
        metadata={
            "help": "required for autoregressive decoders (like seq2seq models); "
            "adds 'prev_output_tokens' to input and appends eos to target"
        },
    )


@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
class AudioFinetuningTask(AudioPretrainingTask):
    """ """

    cfg: AudioFinetuningConfig

    def __init__(
        self,
        cfg: AudioFinetuningConfig,
    ):
        super().__init__(cfg)
        self.blank_symbol = "<s>"

        self.state.add_factory("target_dictionary", self.load_target_dictionary)

    def load_target_dictionary(self):
        if self.cfg.labels:
            dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
            return Dictionary.load(dict_path)
        return None

    def load_dataset(
        self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
    ):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
        text_compressor = TextCompressor(level=text_compression_level)
        with open(label_path, "r") as f:
            labels = [
                text_compressor.compress(l)
                for i, l in enumerate(f)
                if i not in skipped_indices
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            label_len_fn=label_len_fn,
            add_to_input=task_cfg.get("autoregressive", False),
            text_compression_level=text_compression_level,
        )

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.state.target_dictionary

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_wer and self.cfg.autoregressive:
            metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = metrics.sys_len
            logging_output["_bleu_ref_len"] = metrics.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(metrics.counts) == 4
            for i in range(4):
                logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
                logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
        return loss, sample_size, logging_output

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        model = super().build_model(model_cfg, from_checkpoint)

        if self.cfg.eval_wer and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            assert self.cfg.eval_bleu_detok is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_bleu_args)
            gen_args = Namespace(**gen_args)
            self.sequence_generator = self.build_generator([model], gen_args)

        return model

    def _inference_with_wer(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, is_ref):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                    is_ref=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
            logger.info("T-{} {}".format(sample["id"][0], refs[0]))

        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if self.cfg.eval_wer:
            zero = torch.scalar_tensor(0.0)
            num_char_errors = sum(
                log.get("_num_char_errors", zero) for log in logging_outputs
            )
            num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
            num_word_errors = sum(
                log.get("_num_word_errors", zero) for log in logging_outputs
            )
            num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
            metrics.log_scalar("_num_char_errors", num_char_errors)
            metrics.log_scalar("_num_chars", num_chars)
            metrics.log_scalar("_num_word_errors", num_word_errors)
            metrics.log_scalar("_num_words", num_words)
            if num_chars > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: meters["_num_char_errors"].sum
                    * 100.0
                    / meters["_num_chars"].sum
                    if meters["_num_chars"].sum > 0
                    else float("nan"),
                )
            if num_words > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: meters["_num_word_errors"].sum
                    * 100.0
                    / meters["_num_words"].sum
                    if meters["_num_words"].sum > 0
                    else float("nan"),
                )
        if self.cfg.eval_bleu:
            len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
            count_keys = [f"_bleu_counts_{i}" for i in range(4)]
            total_keys = [f"_bleu_totals_{i}" for i in range(4)]
            for k in len_keys + count_keys + total_keys:
                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))

            import sacrebleu

            metrics.log_derived(
                "bleu",
                lambda meters: sacrebleu.compute_bleu(
                    correct=[meters[k].sum for k in count_keys],
                    total=[meters[k].sum for k in total_keys],
                    sys_len=meters["_bleu_sys_len"].sum,
                    ref_len=meters["_bleu_ref_len"].sum,
                    smooth_method="exp",
                ).score,
            )
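The WER bookkeeping in _inference_with_wer and reduce_metrics above boils down to: sum word-level edit distances and reference lengths over the whole validation set, then report 100 * errors / length. A self-contained sketch of the same arithmetic (the sample strings are illustrative, not from this codebase):

import editdistance

hyps = ["the cat sat", "hello word"]        # illustrative hypotheses
refs = ["the cat sat down", "hello world"]  # illustrative references

num_word_errors = sum(
    editdistance.eval(h.split(), r.split()) for h, r in zip(hyps, refs)
)
num_words = sum(len(r.split()) for r in refs)
wer = 100.0 * num_word_errors / num_words  # same ratio as the "wer" derived metric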
205 modules/voice_conversion/fairseq/tasks/audio_pretraining.py Normal file
@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import sys

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING, II, OmegaConf

from fairseq.data import BinarizedAudioDataset, FileAudioDataset
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.text_compressor import TextCompressionLevel

from . import FairseqTask, register_task


logger = logging.getLogger(__name__)


@dataclass
class InferredW2vConfig:
    # The following are needed to precompute mask and mask channel indices
    # before model's forward.
    mask_length: Optional[int] = II("model.mask_length")
    mask_prob: Optional[float] = II("model.mask_prob")
    mask_selection: Optional[str] = II("model.mask_selection")
    mask_other: Optional[float] = II("model.mask_other")
    no_mask_overlap: Optional[bool] = II("model.no_mask_overlap")
    mask_min_space: Optional[int] = II("model.mask_min_space")
    mask_channel_length: Optional[int] = II("model.mask_channel_length")
    mask_channel_prob: Optional[float] = II("model.mask_channel_prob")
    mask_channel_selection: Optional[str] = II("model.mask_channel_selection")
    mask_channel_other: Optional[float] = II("model.mask_channel_other")
    no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap")
    mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space")

    conv_feature_layers: Optional[str] = II("model.conv_feature_layers")
    encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim")


@dataclass
class AudioPretrainingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    labels: Optional[str] = field(
        default=None,
        metadata={"help": "extension of the label file to load, used for fine-tuning"},
    )
    binarized_dataset: bool = field(
        default=False,
        metadata={
            "help": "if true, loads binarized dataset (useful for very large datasets). "
            "See examples/wav2vec/scripts/binarize_manifest.sh"
        },
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False, metadata={"help": "pad shorter samples instead of cropping"}
    )
    max_sample_size: Optional[int] = field(
        default=None, metadata={"help": "max sample size to crop to for batching"}
    )
    min_sample_size: Optional[int] = field(
        default=None, metadata={"help": "min sample size to skip small examples"}
    )
    num_batch_buckets: int = field(
        default=0,
        metadata={"help": "number of buckets"},
    )
    precompute_mask_indices: bool = field(
        default=False,
        metadata={
            "help": "flag to compute mask indices in data preparation.",
        },
    )

    inferred_w2v_config: Optional[InferredW2vConfig] = field(
        default=None,
        metadata={
            "help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)",
        },
    )

    tpu: bool = II("common.tpu")
    text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field(
        default="none",
        metadata={
            "help": "compression level for texts (e.g. audio filenames, "
            "target texts): none/low/high (default: none). "
        },
    )


@register_task("audio_pretraining", dataclass=AudioPretrainingConfig)
class AudioPretrainingTask(FairseqTask):
    """ """

    cfg: AudioPretrainingConfig

    @classmethod
    def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (AudioPretrainingConfig): configuration of this task
        """

        return cls(cfg)

    def _get_mask_precompute_kwargs(self, cfg):
        if self.cfg.precompute_mask_indices or self.cfg.tpu:
            assert (
                cfg.inferred_w2v_config is not None
            ), "inferred_w2v_config must be set"
            return OmegaConf.to_container(
                cfg.inferred_w2v_config, resolve=True, enum_to_str=True
            )
        else:
            return {}

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        data_path = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task
        if isinstance(task_cfg, Namespace):
            if not hasattr(task_cfg, "autoregressive"):
                task_cfg.autoregressive = not task_cfg.criterion == "ctc"

        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        if getattr(task_cfg, "binarized_dataset", False):
            self.datasets[split] = BinarizedAudioDataset(
                data_path,
                split=split,
                sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
                **self._get_mask_precompute_kwargs(task_cfg),
            )
        else:
            manifest_path = os.path.join(data_path, "{}.tsv".format(split))

            self.datasets[split] = FileAudioDataset(
                manifest_path=manifest_path,
                sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
                text_compression_level=text_compression_level,
                **self._get_mask_precompute_kwargs(task_cfg),
            )

        if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0:
            logger.info(
                "Pretraining on TPUs may suffer convergence "
                "issues when training with `mask_channel_prob` value of "
                "0. You may want to set this to a low value close to 0."
            )

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        return None

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return sys.maxsize, sys.maxsize

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        model = super().build_model(model_cfg, from_checkpoint)

        actualized_cfg = getattr(model, "cfg", None)
        if actualized_cfg is not None:
            # if "w2v_args" in actualized_cfg:
            if hasattr(actualized_cfg, "w2v_args"):
                model_cfg.w2v_args = actualized_cfg.w2v_args

        return model
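When binarized_dataset is false, load_dataset above reads a "{split}.tsv" manifest via FileAudioDataset. As a hedged sketch of the manifest layout used by fairseq's wav2vec examples (the paths and sample counts here are made up): the first line is the audio root directory, and each following line is a tab-separated relative path and length in samples.

# minimal sketch of writing a train.tsv manifest; paths/counts are illustrative
with open("train.tsv", "w") as f:
    f.write("/data/audio\n")            # line 1: root directory
    f.write("clip1.wav\t160000\n")      # relative path <TAB> length in samples
    f.write("sub/clip2.wav\t95840\n")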
191 modules/voice_conversion/fairseq/tasks/cross_lingual_lm.py Normal file
@@ -0,0 +1,191 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
from collections import OrderedDict

import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
    """
    Task for training cross-lingual language models.

    For more details look at: https://arxiv.org/pdf/1901.07291.pdf

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments" " per sample",
        )
        parser.add_argument(
            "--monolingual-langs",
            default="en",
            type=str,
            help="comma separated list of languages for which we"
            " want to train XLM on",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="shuffle each monolingual dataset while" " training",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        self.distributed_world_size = args.distributed_world_size
        self.langs2id = self._lang_to_id(args.monolingual_langs)

    def _lang_to_id(self, languages: str):
        """
        Build a map from languages to ids. These ids are used as segment labels
        for cross-lingual LM training.
        """
        lang2id = {}
        langs = [l.strip() for l in languages.split(",")]
        for id, lang in enumerate(langs):
            lang2id[lang] = id
        return lang2id

    @classmethod
    def load_dictionary(cls, filename):
        return MaskedLMDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        d = MaskedLMDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def _load_single_lang_dataset(self, split, epoch):
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)

            ds = data_utils.load_indexed_dataset(
                path, self.dictionary, self.args.dataset_impl
            )
            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                )
            )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset_map = OrderedDict()

        for lang in self.langs2id.keys():
            # Datasets are expected to be in "split.lang" format (Eg: train.en)
            language_split = "{}.{}".format(split, lang)

            block_dataset, sizes = self._load_single_lang_dataset(
                split=language_split, epoch=epoch
            )

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, "shuffle", False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
        logger.info(
            "{} {} {} examples".format(
                utils.split_paths(self.args.data)[epoch - 1],
                split,
                len(self.datasets[split]),
            )
        )
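As a concrete illustration of _lang_to_id above (the flag value is made up): passing --monolingual-langs en,fr,de yields segment ids in list order.

# minimal sketch; "en,fr,de" is an illustrative flag value
langs2id = {lang.strip(): i for i, lang in enumerate("en,fr,de".split(","))}
assert langs2id == {"en": 0, "fr": 1, "de": 2}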
296 modules/voice_conversion/fairseq/tasks/denoising.py Normal file
@@ -0,0 +1,296 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
from omegaconf import II, MISSING

from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    DenoisingDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    PadDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TokenBlockDataset,
    data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from ..data.indexed_dataset import get_available_dataset_impl

logger = logging.getLogger(__name__)

SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"])


@dataclass
class DenoisingConfig(FairseqDataclass):
    data: str = field(
        default=MISSING,
        metadata={"help": "path to data directory"},
    )
    bpe: Optional[str] = field(
        default=None,
        metadata={"help": "TODO"},
    )
    tokens_per_sample: int = field(
        default=512,
        metadata={
            "help": "max number of total tokens over all segments "
            "per sample for dataset"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="complete_doc",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    replace_length: int = field(
        default=0,
        metadata={"help": "TODO, should only allow -1, 0 and 1"},
    )
    mask: float = field(
        default=0.0,
        metadata={"help": "fraction of words/subwords that will be masked"},
    )
    mask_random: float = field(
        default=0.0,
        metadata={"help": "instead of using [MASK], use random token this often"},
    )
    insert: float = field(
        default=0.0,
        metadata={"help": "insert this percentage of additional random tokens"},
    )
    permute: float = field(
        default=0.0,
        metadata={"help": "take this proportion of subwords and permute them"},
    )
    rotate: float = field(
        default=0.5,
        metadata={"help": "rotate this proportion of inputs"},
    )
    poisson_lambda: float = field(
        default=3.0,
        metadata={"help": "randomly shuffle sentences for this proportion of inputs"},
    )
    shuffle_instance: float = field(
        default=0.0,
        metadata={"help": "shuffle this proportion of sentences in all inputs"},
    )
    mask_length: MASK_LENGTH_CHOICES = field(
        default="subword",
        metadata={"help": "mask length to choose"},
    )
    permute_sentences: int = field(
        default=-1,
        metadata={
            "help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"
        },
    )
    seed: int = II("common.seed")
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    max_source_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the source sequence"},
    )
    max_target_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the target sequence"},
    )
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )


@register_task("denoising", dataclass=DenoisingConfig)
class DenoisingTask(FairseqTask):
    """
    Denoising task for applying sequence to sequence denoising. (ie. BART)
    """

    cfg: DenoisingConfig

    def __init__(self, cfg, dictionary):
        super().__init__(cfg)
        self.dictionary = dictionary

        # add mask token
        self.mask_idx = self.dictionary.add_symbol("<mask>")

    @classmethod
    def setup_task(cls, cfg: DenoisingConfig, **kwargs):
        """Setup the task."""
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle_instance"):
            cfg.shuffle_instance = False
        return cls(cfg, dictionary)

    def _load_dataset_split(self, split, epoch, combine):
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.cfg.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.cfg.tokens_per_sample - 2,
            # one less for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
        return dataset

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        mask_whole_words = (
            get_whole_word_mask(self.cfg.bpe, self.source_dictionary)
            if self.cfg.mask_length != "subword"
            else None
        )

        self.datasets[split] = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.cfg.shuffle_instance,
            seed=self.cfg.seed,
            mask=self.cfg.mask,
            mask_random=self.cfg.mask_random,
            insert=self.cfg.insert,
            rotate=self.cfg.rotate,
            permute_sentences=self.cfg.permute_sentences,
            bpe=self.cfg.bpe,
            replace_length=self.cfg.replace_length,
            mask_length=self.cfg.mask_length,
            poisson_lambda=self.cfg.poisson_lambda,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split,
                len(self.datasets[split]),
            )
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Generate batches for inference. We assume that the input begins with a
        bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
        """
        pad = self.source_dictionary.pad()
        eos = self.source_dictionary.eos()
        src_dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
            pad=pad,
            eos=eos,
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        prev_output_tokens = PrependTokenDataset(
            StripTokenDataset(src_dataset, eos), eos
        )
        src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                    "prev_output_tokens": PadDataset(
                        prev_output_tokens, pad_idx=pad, left_pad=False
                    ),
                },
                "target": src_dataset,
            },
            sizes=[np.array(src_lengths)],
        )

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.dictionary
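One detail worth noting in _load_dataset_split above: blocks are sized to tokens_per_sample - 2 precisely so that the later PrependTokenDataset (<s>) and AppendTokenDataset (</s>) wrappers keep each example within tokens_per_sample. A trivial check of that invariant (512 is just the config default shown above):

tokens_per_sample = 512             # config default shown above
block_size = tokens_per_sample - 2  # room reserved for <s> and </s>
assert block_size + 2 == tokens_per_sample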
693 modules/voice_conversion/fairseq/tasks/fairseq_task.py Normal file
@@ -0,0 +1,693 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Callable, Dict, List

import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.amp_optimizer import AMPOptimizer
from omegaconf import DictConfig


logger = logging.getLogger(__name__)


class StatefulContainer(object):
    def __init__(self):
        self._state = dict()
        self._factories = dict()

    def add_factory(self, name, factory: Callable[[], Any]):
        self._factories[name] = factory

    def merge_state_dict(self, state_dict: Dict[str, Any]):
        self._state.update(state_dict)

    @property
    def state_dict(self) -> Dict[str, Any]:
        return self._state

    def __getattr__(self, name):
        if name not in self._state and name in self._factories:
            self._state[name] = self._factories[name]()

        if name in self._state:
            return self._state[name]

        raise AttributeError(f"Task state has no factory for attribute {name}")


class FairseqTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.

    Tasks have limited statefulness. In particular, state that needs to be
    saved to/loaded from checkpoints needs to be stored in the `self.state`
    :class:`StatefulContainer` object. For example::

        self.state.add_factory("dictionary", self.load_dictionary)
        print(self.state.dictionary)  # calls self.load_dictionary()

    This is necessary so that when loading checkpoints, we can properly
    recreate the task state after initializing the task instance.
    """

    @classmethod
    def add_args(cls, parser):
        """Add task-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @staticmethod
    def logging_outputs_can_be_summed(criterion) -> bool:
        """
        Whether the logging outputs returned by `train_step` and `valid_step` can
        be summed across workers prior to calling `aggregate_logging_outputs`.
        Setting this to True will improve distributed training speed.
        """
        return criterion.logging_outputs_can_be_summed()

    def __init__(self, cfg: FairseqDataclass, **kwargs):
        self.cfg = cfg
        self.datasets = dict()
        self.dataset_to_epoch_iter = dict()
        self.state = StatefulContainer()

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (omegaconf.DictConfig): parsed command-line arguments
        """
        return cls(cfg, **kwargs)

    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.cfg, "data", "")

    def load_dataset(
        self,
        split: str,
        combine: bool = False,
        task_cfg: FairseqDataclass = None,
        **kwargs,
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            combine (bool): combines a split segmented into pieces into one dataset
            task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used
                to load datasets
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
        """
        from fairseq.data import FairseqDataset

        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        if not isinstance(self.datasets[split], FairseqDataset):
            raise TypeError("Datasets are expected to be of type FairseqDataset")
        return self.datasets[split]

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
        Returns:
            np.array: array of filtered sample indices
        """
        indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
        if len(ignored) > 0:
            if not ignore_invalid_inputs:
                raise Exception(
                    (
                        "Size of sample #{} is invalid (={}) since max_positions={}, "
                        "skip this example with --skip-invalid-size-inputs-valid-test"
                    ).format(ignored[0], dataset.size(ignored[0]), max_positions)
                )
            logger.warning(
                (
                    "{:,} samples have invalid sizes and will be skipped, "
                    "max_positions={}, first few sample ids={}"
                ).format(len(ignored), max_positions, ignored[:10])
            )
        return indices

    def can_reuse_epoch_itr(self, dataset):
        # We can reuse the epoch iterator across epochs as long as the dataset
        # hasn't disabled it. We default to ``False`` here, although in practice
        # this will be ``True`` for most datasets that inherit from
        # ``FairseqDataset`` due to the base implementation there.
        return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            skip_remainder_batch (bool, optional): if set, discard the last
                batch in each training epoch, as the last batch is often smaller than
                local_batch_size * distributed_word_size (default: ``True``).
            grouped_shuffling (bool, optional): group batches with each groups
                containing num_shards batches and shuffle groups. Reduces difference
                between sequence lengths among workers for batches sorted by length.
            update_epoch_batch_itr (bool, optional): if true then do not use the cached
                batch iterator for the epoch

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        can_reuse_epoch_itr = (
            not disable_iterator_cache
            and not update_epoch_batch_itr
            and self.can_reuse_epoch_itr(dataset)
        )
        if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
            logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            indices = self.filter_indices_by_size(
                indices, dataset, max_positions, ignore_invalid_inputs
            )

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
        persistent_workers = getattr(self.cfg, "persistent_workers", False)

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            skip_remainder_batch=skip_remainder_batch,
            grouped_shuffling=grouped_shuffling,
            reuse_dataloader=reuse_dataloader,
            persistent_workers=persistent_workers,
        )

        if can_reuse_epoch_itr:
            self.dataset_to_epoch_iter[dataset] = epoch_iter

        return epoch_iter

    def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            cfg (FairseqDataclass): configuration object

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(cfg, self, from_checkpoint)
        model = quantization_utils.quantize_model_scalar(model, cfg)
        return model

    def build_criterion(self, cfg: DictConfig):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            cfg (omegaconf.DictConfig): configuration object

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(cfg, self)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
        prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
|
||||
len_penalty=getattr(args, "lenpen", 1),
|
||||
unk_penalty=getattr(args, "unkpen", 0),
|
||||
temperature=getattr(args, "temperature", 1.0),
|
||||
match_source_len=getattr(args, "match_source_len", False),
|
||||
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
|
||||
search_strategy=search_strategy,
|
||||
**extra_gen_cls_kwargs,
|
||||
)
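
    # Usage sketch (illustrative, not part of the vendored file): the
    # `prefix_allowed_tokens_fn` hook documented above receives the batch ID
    # and the tokens generated so far, and returns the token IDs permitted at
    # the next step. A minimal constraint that only allows continuing with a
    # fixed set of words could look like this (`tgt_dict`, `gen_args`, and the
    # word choices are hypothetical):
    #
    #     allowed = {tgt_dict.index(w) for w in ("yes", "no")}
    #
    #     def prefix_allowed_tokens_fn(batch_id, inputs_ids):
    #         # Permit EOS once something has been generated; otherwise
    #         # restrict the beam to the allowed word IDs.
    #         if inputs_ids.numel() > 1:
    #             return sorted(allowed | {tgt_dict.eos()})
    #         return sorted(allowed)
    #
    #     generator = task.build_generator(
    #         models, gen_args, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn
    #     )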

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output
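
    # Minimal driver sketch (illustrative; fairseq's Trainer adds gradient
    # accumulation, clipping, and distributed sync on top of this). Assumes
    # `task`, `model`, `criterion`, `optimizer`, and an `epoch_itr` built by
    # `get_batch_iterator` already exist:
    #
    #     for i, sample in enumerate(epoch_itr.next_epoch_itr(shuffle=True)):
    #         loss, sample_size, logging_output = task.train_step(
    #             sample, model, criterion, optimizer, update_num=i
    #         )
    #         task.optimizer_step(optimizer, model, update_num=i)
    #         optimizer.zero_grad()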

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def optimizer_step(self, optimizer, model, update_num):
        optimizer.step()

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, constraints=constraints
            )

    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""
        pass

    def begin_valid_epoch(self, epoch, model):
        """Hook function called before the start of each validation epoch."""
        pass

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """[deprecated] Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        with metrics.aggregate() as agg:
            self.reduce_metrics(logging_outputs, criterion)
            return agg.get_smoothed_values()

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)
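
    # Override sketch for a custom task (illustrative; "my_metric" is a
    # hypothetical key a custom criterion might emit in its logging outputs):
    #
    #     def reduce_metrics(self, logging_outputs, criterion):
    #         super().reduce_metrics(logging_outputs, criterion)
    #         total = sum(log.get("my_metric", 0) for log in logging_outputs)
    #         metrics.log_scalar("my_metric", total, round=3)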

    def state_dict(self):
        if self.state is not None:
            return self.state.state_dict
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        if self.state is not None:
            self.state.merge_state_dict(state_dict)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    def build_tokenizer(self, args):
        """Build the pre-tokenizer for this task."""
        return encoders.build_tokenizer(args)

    def build_bpe(self, args):
        """Build the tokenizer for this task."""
        return encoders.build_bpe(args)

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        tokens = [
            self.source_dictionary.encode_line(
                encode_fn(src_str), add_if_not_exist=False
            ).long()
            for src_str in lines
        ]
        lengths = [t.numel() for t in tokens]
        return tokens, lengths


class LegacyFairseqTask(FairseqTask):
    def __init__(self, args: Namespace):
        super().__init__(None)
        self.args = args
        self.datasets = {}
        self.dataset_to_epoch_iter = {}

    @classmethod
    def setup_task(cls, args: Namespace, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.args, "data", "")

    def build_model(self, args: Namespace, from_checkpoint=False):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(args, self, from_checkpoint)
        model = quantization_utils.quantize_model_scalar(model, args)
        return model

    def build_criterion(self, args: Namespace):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(args, self)
55
modules/voice_conversion/fairseq/tasks/frm_text_to_speech.py
Normal file
@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.text_to_speech import TextToSpeechTask


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


@register_task("frm_text_to_speech")
class FrmTextToSpeechTask(TextToSpeechTask):
    @staticmethod
    def add_args(parser):
        TextToSpeechTask.add_args(parser)
        parser.add_argument("--do_chunk", action="store_true", help="train on chunks")
        parser.add_argument("--chunk_bound", default=-1, type=int)
        parser.add_argument("--chunk_init", default=50, type=int)
        parser.add_argument("--chunk_incr", default=5, type=int)
        parser.add_argument("--add_eos", action="store_true")
        parser.add_argument("--dedup", action="store_true")
        parser.add_argument("--ref_fpu", default=-1, type=float)

    def load_dataset(self, split, **unused_kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
            do_chunk=self.args.do_chunk,
            chunk_bound=self.args.chunk_bound,
            chunk_init=self.args.chunk_init,
            chunk_incr=self.args.chunk_incr,
            add_eos=self.args.add_eos,
            dedup=self.args.dedup,
            ref_fpu=self.args.ref_fpu,
        )
191
modules/voice_conversion/fairseq/tasks/hubert_pretraining.py
Normal file
@@ -0,0 +1,191 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import sys
from typing import Dict, List, Optional, Tuple

import numpy as np

from dataclasses import dataclass, field
from fairseq.data import Dictionary, HubertDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING

logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )


@dataclass
class HubertPretrainingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    fine_tuning: bool = field(
        default=False, metadata={"help": "set to true if fine-tuning Hubert"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "if set, looks for labels in this directory instead",
        },
    )
    label_rate: float = field(
        default=-1.0,
        metadata={"help": "label frame rate. -1.0 for sequence label"},
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude sample longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to crop to for batching"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to crop to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )


@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig)
class HubertPretrainingTask(FairseqTask):

    cfg: HubertPretrainingConfig

    def __init__(
        self,
        cfg: HubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"HubertPretrainingTask Config {cfg}")

        self.cfg = cfg
        self.fine_tuning = cfg.fine_tuning

        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return None

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self.state.target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return self.state.dictionaries

    @classmethod
    def setup_task(
        cls, cfg: HubertPretrainingConfig, **kwargs
    ) -> "HubertPretrainingTask":
        return cls(cfg)

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def get_label_dir(self) -> str:
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.datasets[split] = HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
        return indices
383
modules/voice_conversion/fairseq/tasks/language_modeling.py
Normal file
@@ -0,0 +1,383 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    Dictionary,
    IdDataset,
    LMContextWindowDataset,
    MonolingualDataset,
    NestedDictionaryDataset,
    NumelDataset,
    PadDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TokenBlockDataset,
    TruncatedDictionary,
    data_utils,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import LegacyFairseqTask, register_task
from omegaconf import II


SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
logger = logging.getLogger(__name__)


@dataclass
class LanguageModelingConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None, metadata={"help": "path to data directory"}
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    output_dictionary_size: int = field(
        default=-1, metadata={"help": "limit the size of output dictionary"}
    )
    self_target: bool = field(default=False, metadata={"help": "include self target"})
    future_target: bool = field(
        default=False, metadata={"help": "include future target"}
    )
    past_target: bool = field(default=False, metadata={"help": "include past target"})
    add_bos_token: bool = field(
        default=False, metadata={"help": "prepend beginning of sentence token (<s>)"}
    )
    max_target_positions: Optional[int] = field(
        default=None, metadata={"help": "max number of tokens in the target sequence"}
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    pad_to_fixed_length: Optional[bool] = field(
        default=False,
        metadata={"help": "pad to fixed length"},
    )
    pad_to_fixed_bsz: Optional[bool] = field(
        default=False,
        metadata={"help": "boolean to pad to fixed batch size"},
    )

    # TODO common vars below add to parent
    seed: int = II("common.seed")
    batch_size: Optional[int] = II("dataset.batch_size")
    batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    data_buffer_size: int = II("dataset.data_buffer_size")
    tpu: bool = II("common.tpu")
    use_plasma_view: bool = II("common.use_plasma_view")
    plasma_path: str = II("common.plasma_path")


@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(LegacyFairseqTask):
    """
    Train a language model.

    Args:
        dictionary (~fairseq.data.Dictionary): the dictionary for the input of
            the language model
        output_dictionary (~fairseq.data.Dictionary): the dictionary for the
            output of the language model. In most cases it will be the same as
            *dictionary*, but could possibly be a more limited version of the
            dictionary (if ``--output-dictionary-size`` is used).
        targets (List[str]): list of the target types that the language model
            should predict. Can be one of "self", "future", and "past".
            Defaults to "future".

    .. note::

        The language modeling task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate`, :mod:`fairseq-interactive` and
        :mod:`fairseq-eval-lm`.

    The language modeling task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.language_modeling_parser
        :prog:
    """

    def __init__(self, args, dictionary, output_dictionary=None, targets=None):
        super().__init__(args)
        self.dictionary = dictionary
        self.output_dictionary = output_dictionary or dictionary

        if targets is None:
            targets = ["future"]
        self.targets = targets

    @classmethod
    def setup_dictionary(cls, args, **kwargs):
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
            logger.info("dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size
                )
        return (dictionary, output_dictionary)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)

        # upgrade old checkpoints
        if getattr(args, "exclude_self_target", False):
            args.self_target = False

        targets = []
        if getattr(args, "self_target", False):
            targets.append("self")
        if getattr(args, "future_target", False):
            targets.append("future")
        if getattr(args, "past_target", False):
            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)
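
    # Worked example (illustrative, using hypothetical argparse values):
    #
    #     args.self_target, args.future_target, args.past_target = True, False, True
    #     task = LanguageModelingTask.setup_task(args)
    #     assert task.targets == ["self", "past"]
    #
    # With all three flags unset, targets falls back to ["future"], i.e.
    # ordinary next-token prediction; build_model below then rejects any
    # target the chosen architecture does not list in supported_targets.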

    def build_model(self, args, from_checkpoint=False):
        model = super().build_model(args, from_checkpoint)
        for target in self.targets:
            if target not in model.supported_targets:
                raise ValueError(
                    "Unsupported language modeling target: {}".format(target)
                )

        return model

    def load_dataset(
        self, split: str, epoch=1, combine=False, **kwargs
    ) -> MonolingualDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, valid1, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        # each process has its own copy of the raw data (likely to be an np.memmap)
        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if dataset is None:
            raise FileNotFoundError(f"Dataset not found: {split} ({split_path})")

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
            use_plasma_view=self.args.use_plasma_view,
            split_path=split_path,
            plasma_path=self.args.plasma_path,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )
        fixed_pad_length = None
        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (
                self.args.batch_size_valid if "valid" in split else self.args.batch_size
            )

        self.datasets[split] = MonolingualDataset(
            dataset=dataset,
            sizes=dataset.sizes,
            src_vocab=self.dictionary,
            tgt_vocab=self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
            fixed_pad_length=fixed_pad_length,
            pad_to_bsz=pad_to_bsz,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Generate batches for inference. We prepend an eos token to src_tokens
        (or bos if `--add-bos-token` is set) and we append a <pad> to target.
        This is convenient both for generation with a prefix and LM scoring.
        """
        dataset = StripTokenDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                block_size=None,  # ignored for "eos" break mode
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            ),
            # remove eos from (end of) target sequence
            self.source_dictionary.eos(),
        )
        src_dataset = PrependTokenDataset(
            dataset,
            token=(
                self.source_dictionary.bos()
                if getattr(self.args, "add_bos_token", False)
                else self.source_dictionary.eos()
            ),
        )
        tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": PadDataset(
                    tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False
                ),
            },
            sizes=[np.array(src_lengths)],
        )
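
    # Token-layout sketch for the method above (illustrative): given a prefix
    # "a b c" with --add-bos-token unset, the stripped block dataset holds
    # [a, b, c]; the model input becomes [eos, a, b, c] (eos prepended as the
    # start symbol) and the scoring target becomes [a, b, c, pad] (pad
    # appended), so input and target stay aligned position by position.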

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            # Generation will always be conditioned on bos_token
            if getattr(self.args, "add_bos_token", False):
                bos_token = self.source_dictionary.bos()
            else:
                bos_token = self.source_dictionary.eos()

            if constraints is not None:
                raise NotImplementedError(
                    "Constrained decoding with the language_modeling task is not supported"
                )

            # SequenceGenerator doesn't use src_tokens directly, we need to
            # pass the `prefix_tokens` argument instead
            if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
                prefix_tokens = sample["net_input"]["src_tokens"]
                if prefix_tokens[:, 0].eq(bos_token).all():
                    prefix_tokens = prefix_tokens[:, 1:]

            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
            )

    def eval_lm_dataloader(
        self,
        dataset,
        max_tokens: Optional[int] = 36000,
        batch_size: Optional[int] = None,
        max_positions: Optional[int] = None,
        num_shards: int = 1,
        shard_id: int = 0,
        num_workers: int = 1,
        data_buffer_size: int = 10,
        # ensures that every evaluated token has access to a context of at least
        # this size, if possible
        context_window: int = 0,
    ):
        if context_window > 0:
            dataset = LMContextWindowDataset(
                dataset=dataset,
                tokens_per_sample=self.args.tokens_per_sample,
                context_window=context_window,
                pad_idx=self.source_dictionary.pad(),
            )
        return self.get_batch_iterator(
            dataset=dataset,
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            data_buffer_size=data_buffer_size,
        ).next_epoch_itr(shuffle=False)

    @property
    def source_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.output_dictionary
152
modules/voice_conversion/fairseq/tasks/legacy_masked_lm.py
Normal file
@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os

import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("legacy_masked_lm")
class LegacyMaskedLMTask(LegacyFairseqTask):
    """
    Task for training Masked LM (BERT) model.
    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments"
            " per sample for BERT dataset",
        )
        parser.add_argument(
            "--break-mode", default="doc", type=str, help="mode for breaking sentence"
        )
        parser.add_argument("--shuffle-dataset", action="store_true", default=False)

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

    @classmethod
    def load_dictionary(cls, filename):
        return BertDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        d = BertDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        logger.info("data_path: %s", data_path)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    )
                )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
270
modules/voice_conversion/fairseq/tasks/masked_lm.py
Normal file
@@ -0,0 +1,270 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field

import numpy as np
from omegaconf import II, MISSING, OmegaConf

from fairseq import utils
from fairseq.data import (
    Dictionary,
    IdDataset,
    MaskTokensDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PrependTokenDataset,
    RightPadDataset,
    SortDataset,
    TokenBlockDataset,
    data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES

logger = logging.getLogger(__name__)


@dataclass
class MaskedLMConfig(FairseqDataclass):
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    mask_prob: float = field(
        default=0.15,
        metadata={"help": "probability of replacing a token with mask"},
    )
    leave_unmasked_prob: float = field(
        default=0.1,
        metadata={"help": "probability that a masked token is unmasked"},
    )
    random_token_prob: float = field(
        default=0.1,
        metadata={"help": "probability of replacing a token with a random token"},
    )
    freq_weighted_replacement: bool = field(
        default=False,
        metadata={"help": "sample random replacement words based on word frequencies"},
    )
    mask_whole_words: bool = field(
        default=False,
        metadata={"help": "mask whole words; you may also want to set --bpe"},
    )
    mask_multiple_length: int = field(
        default=1,
        metadata={"help": "repeat the mask indices multiple times"},
    )
    mask_stdev: float = field(
        default=0.0,
        metadata={"help": "stdev of the mask length"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    seed: int = II("common.seed")

    include_target_tokens: bool = field(
        default=False,
        metadata={
            "help": "include target tokens in model input. this is used for data2vec"
        },
    )
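
    # Worked example (illustrative) of how the defaults above combine,
    # following the usual BERT recipe: mask_prob=0.15 selects ~15% of tokens
    # as prediction targets; of those, leave_unmasked_prob=0.1 stay unchanged
    # and random_token_prob=0.1 become random tokens, so roughly
    # 0.15 * (1 - 0.1 - 0.1) = 12% of all tokens end up as <mask>,
    # 1.5% as random tokens, and 1.5% are left as-is.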


@register_task("masked_lm", dataclass=MaskedLMConfig)
class MaskedLMTask(FairseqTask):
    """Task for training masked language models (e.g., BERT, RoBERTa)."""

    cfg: MaskedLMConfig

    def __init__(self, cfg: MaskedLMConfig, dictionary):
        super().__init__(cfg)
        self.dictionary = dictionary

        # add mask token
        self.mask_idx = dictionary.add_symbol("<mask>")

    @classmethod
    def setup_task(cls, cfg: MaskedLMConfig, **kwargs):
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(cfg, dictionary)

    def _load_dataset_split(self, split, epoch, combine):
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.cfg.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        return PrependTokenDataset(dataset, self.source_dictionary.bos())

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        # create masked input and targets
        mask_whole_words = (
            get_whole_word_mask(self.args, self.source_dictionary)
            if self.cfg.mask_whole_words
            else None
        )

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.cfg.seed,
            mask_prob=self.cfg.mask_prob,
            leave_unmasked_prob=self.cfg.leave_unmasked_prob,
            random_token_prob=self.cfg.random_token_prob,
            freq_weighted_replacement=self.cfg.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=self.cfg.mask_multiple_length,
            mask_stdev=self.cfg.mask_stdev,
        )

        with data_utils.numpy_seed(self.cfg.seed):
            shuffle = np.random.permutation(len(src_dataset))

        target_dataset = RightPadDataset(
            tgt_dataset,
            pad_idx=self.source_dictionary.pad(),
        )

        input_dict = {
            "src_tokens": RightPadDataset(
                src_dataset,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": NumelDataset(src_dataset, reduce=False),
        }
        if self.cfg.include_target_tokens:
            input_dict["target_tokens"] = target_dataset

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    "id": IdDataset(),
                    "net_input": input_dict,
                    "target": target_dataset,
                    "nsentences": NumSamplesDataset(),
                    "ntokens": NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
        src_dataset = RightPadDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                self.cfg.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            ),
            pad_idx=self.source_dictionary.pad(),
        )
        src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
        src_dataset = NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
            },
            sizes=src_lengths,
        )
        if sort:
            src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
        return src_dataset

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
268
modules/voice_conversion/fairseq/tasks/multilingual_denoising.py
Normal file
@@ -0,0 +1,268 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from omegaconf import II

from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    DenoisingDataset,
    Dictionary,
    PrependTokenDataset,
    ResamplingDataset,
    SortDataset,
    TokenBlockDataset,
    data_utils,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import register_task

from .denoising import DenoisingConfig, DenoisingTask

logger = logging.getLogger(__name__)


@dataclass
class MultilingualDenoisingConfig(DenoisingConfig):
    multilang_sampling_alpha: float = field(
        default=1.0,
        metadata={"help": "smoothing alpha for sample ratios across multiple datasets"},
    )
    add_lang_token: bool = field(
        default=False,
        metadata={"help": ""},
    )
    langs: Optional[str] = field(
        default=None,
        metadata={"help": "language ids we are considering"},
    )
    no_whole_word_mask_langs: str = field(
        default="",
        metadata={
            "help": "languages without spacing between words don't support whole word masking"
        },
    )
    train_subset: str = II("common.train_subset")
    valid_subset: str = II("common.valid_subset")


@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig)
class MultilingualDenoisingTask(DenoisingTask):

    cfg: MultilingualDenoisingConfig

    @classmethod
    def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs):
        """Setup the task."""
        paths = cfg.data.split(":")
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))

        data_path = paths[0]
        if cfg.langs is None:
            languages = sorted(
                [
                    name
                    for name in os.listdir(data_path)
                    if os.path.isdir(os.path.join(data_path, name))
                ]
            )
        else:
            languages = cfg.langs.split(",")

        if cfg.add_lang_token:
            for lang in languages:
                dictionary.add_symbol("[{}]".format(lang))

        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle_instance"):
            cfg.shuffle_instance = False
        return cls(cfg, dictionary)

    def __init__(self, cfg: MultilingualDenoisingConfig, dictionary):
        super().__init__(cfg, dictionary)
        self.dictionary = dictionary

        # add mask token
        self.mask_idx = self.dictionary.add_symbol("<mask>")
        self.cfg = cfg

    def _get_sample_prob(self, dataset_lens):
        """
        Get smoothed sampling probability by languages. This helps low resource
        languages by upsampling them.
        """
        prob = dataset_lens / dataset_lens.sum()
        smoothed_prob = prob**self.cfg.multilang_sampling_alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        return smoothed_prob
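
    # Worked example (illustrative): with dataset lengths [900, 100] and
    # multilang_sampling_alpha=0.5, the raw probabilities are [0.9, 0.1];
    # raising to the 0.5 power gives [0.949, 0.316], which renormalizes to
    # about [0.75, 0.25]. The low-resource language's share thus rises from
    # 10% to ~25%, while alpha=1.0 would leave the raw proportions unchanged.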

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.cfg.data.split(":")
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        if self.cfg.langs is None:
            languages = sorted(
                [
                    name
                    for name in os.listdir(data_path)
                    if os.path.isdir(os.path.join(data_path, name))
                ]
            )
        else:
            languages = self.cfg.langs.split(",")
            for name in languages:
                p = os.path.join(data_path, name)
                assert os.path.exists(p), "data not found: {}".format(p)

        logger.info("Training on {0} languages: {1}".format(len(languages), languages))
        logger.info(
            "Language to id mapping: %s", {lang: id for id, lang in enumerate(languages)}
        )

        mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary)
        language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",")
        lang_datasets = []
        for language in languages:
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.cfg.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            end_token = (
                self.source_dictionary.index("[{}]".format(language))
                if self.cfg.add_lang_token
                else self.source_dictionary.eos()
            )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.cfg.tokens_per_sample - 2,  # two fewer for <s> and the end/lang token
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.cfg.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_mask_whole_words = (
                mask_whole_words
                if language not in language_without_segmentations
                else None
            )
            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                lang_mask_whole_words,
                shuffle=self.cfg.shuffle_instance,
                seed=self.cfg.seed,
                mask=self.cfg.mask,
                mask_random=self.cfg.mask_random,
                insert=self.cfg.insert,
                rotate=self.cfg.rotate,
                permute_sentences=self.cfg.permute_sentences,
                bpe=self.cfg.bpe,
                replace_length=self.cfg.replace_length,
                mask_length=self.cfg.mask_length,
                poisson_lambda=self.cfg.poisson_lambda,
                eos=None
                if not self.cfg.add_lang_token
                else self.source_dictionary.index("[{}]".format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                int(dataset_lengths.sum()),
            )
        )
        if split == self.cfg.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: {}".format(
                    {
                        lang: "{0:.4f}".format(sample_probs[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: {}".format(
                    {
                        lang: "{0:.2f}".format(size_ratio[id])
                        for id, lang in enumerate(languages)
                    }
                )
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.cfg.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(
                resampled_lang_datasets,
            )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if split in self.cfg.valid_subset:
                self.cfg.valid_subset = self.cfg.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.cfg.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
@@ -0,0 +1,627 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from omegaconf import II

from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    Dictionary,
    IdDataset,
    LMContextWindowDataset,
    MonolingualDataset,
    NestedDictionaryDataset,
    NumelDataset,
    PadDataset,
    PrependTokenDataset,
    ResamplingDataset,
    SortDataset,
    StripTokenDataset,
    TokenBlockDataset,
    TruncatedDictionary,
    data_utils,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import LegacyFairseqTask, register_task

SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
logger = logging.getLogger(__name__)


def lang_token(lang):
    return f"<{lang}>"


@dataclass
class MultilingualLanguageModelingConfig(FairseqDataclass):
    # TODO common var add to parent
    data: Optional[str] = field(
        default=None, metadata={"help": "path to data directory"}
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    output_dictionary_size: int = field(
        default=-1, metadata={"help": "limit the size of output dictionary"}
    )
    self_target: bool = field(default=False, metadata={"help": "include self target"})
    future_target: bool = field(
        default=False, metadata={"help": "include future target"}
    )
    past_target: bool = field(default=False, metadata={"help": "include past target"})
    add_bos_token: bool = field(
        default=False, metadata={"help": "prepend lang id token <dialect>"}
    )
    max_source_positions: Optional[int] = field(
        default=None, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: Optional[int] = field(
        default=None, metadata={"help": "max number of tokens in the target sequence"}
    )
    pad_to_fixed_length: Optional[bool] = field(
        default=False, metadata={"help": "pad to fixed length"}
    )
    pad_to_fixed_bsz: Optional[bool] = field(
        default=False, metadata={"help": "boolean to pad to fixed batch size"}
    )

    multilang_sampling_alpha: Optional[float] = field(
        default=1.0,
        metadata={
            "help": "smoothing alpha for sample ratios across multiple datasets"
        },
    )

    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )

    langs: str = field(
        default="",
        metadata={
            "help": "comma-separated list of languages (default: all directories in data path)"
        },
    )
    baseline_model_langs: str = field(
        default="",
        metadata={
            "help": "comma-separated list of languages in the baseline model (default: none)"
        },
    )
    # TODO: legacy parameter kept for compatibility
    baseline_model: str = field(
        default="",
        metadata={"help": "path to the baseline model (default: none)"},
    )

    lang_to_offline_shard_ratio: str = field(
        default="",
        metadata={
            "help": "absolute path of tsv file location to indicate lang to offline shard ratio.",
        },
    )
    # TODO common vars below add to parent
    seed: int = II("common.seed")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    data_buffer_size: int = II("dataset.data_buffer_size")
    tpu: bool = II("common.tpu")
    batch_size: Optional[int] = II("dataset.batch_size")
    batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
    train_subset: str = II("common.train_subset")
    valid_subset: str = II("common.valid_subset")


@register_task(
    "multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig
)
class MultilingualLanguageModelingTask(LegacyFairseqTask):
    """
    Train a language model.

    Args:
        dictionary (~fairseq.data.Dictionary): the dictionary for the input of
            the language model
        output_dictionary (~fairseq.data.Dictionary): the dictionary for the
            output of the language model. In most cases it will be the same as
            *dictionary*, but could possibly be a more limited version of the
            dictionary (if ``--output-dictionary-size`` is used).
        targets (List[str]): list of the target types that the language model
            should predict. Can be one of "self", "future", and "past".
            Defaults to "future".

    .. note::

        The language modeling task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate`, :mod:`fairseq-interactive` and
        :mod:`fairseq-eval-lm`.

    The language modeling task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.language_modeling_parser
        :prog:
    """

    def __init__(self, args, dictionary, output_dictionary=None, targets=None):
        super().__init__(args)
        self.dictionary = dictionary
        self.output_dictionary = output_dictionary or dictionary

        if targets is None:
            targets = ["future"]
        self.targets = targets

    @staticmethod
    def _get_langs(args, epoch=1):
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(
            name
            for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name))
        )
        if args.langs:
            keep_langs = set(args.langs.split(","))
            languages = [lang for lang in languages if lang in keep_langs]
            assert len(languages) == len(keep_langs)

        return languages, data_path

    @classmethod
    def setup_dictionary(cls, args, **kwargs):
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
            if args.add_bos_token:
                languages, _ = cls._get_langs(args)
                logger.info("----------------")
                for lang in languages:
                    dictionary.add_symbol(lang_token(lang))
                    logger.info(f"add language token: {lang_token(lang)}")
                logger.info("----------------")

            logger.info("dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size
                )
        return (dictionary, output_dictionary)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs)

        # upgrade old checkpoints
        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, "self_target", False):
            targets.append("self")
        if getattr(args, "future_target", False):
            targets.append("future")
        if getattr(args, "past_target", False):
            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)

    def build_model(self, args, from_checkpoint=False):
        model = super().build_model(args, from_checkpoint)
        for target in self.targets:
            if target not in model.supported_targets:
                raise ValueError(
                    f"Unsupported language modeling target: {target} not in {model.supported_targets}"
                )

        return model

    def _get_sample_prob(self, dataset_lens):
        """
        Get smoothed sampling probability by languages. This helps low resource
        languages by upsampling them.
        """
        prob = dataset_lens / dataset_lens.sum()
        smoothed_prob = prob**self.args.multilang_sampling_alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        return smoothed_prob
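
The smoothing exponent interpolates between proportional and uniform sampling. A minimal standalone check of the same formula (toy numbers, for illustration only):

import numpy as np

def get_sample_prob(dataset_lens, alpha):
    # free-function version of _get_sample_prob above, for experimentation
    prob = dataset_lens / dataset_lens.sum()
    smoothed = prob**alpha
    return smoothed / smoothed.sum()

lens = np.array([900.0, 100.0])
print(get_sample_prob(lens, 1.0))  # [0.9  0.1 ] - proportional to corpus size
print(get_sample_prob(lens, 0.0))  # [0.5  0.5 ] - uniform across languages
print(get_sample_prob(lens, 0.5))  # [0.75 0.25] - in between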

    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        languages, data_path = MultilingualLanguageModelingTask._get_langs(
            self.args, epoch
        )
        lang_to_offline_shard_ratio = None
        if self.args.lang_to_offline_shard_ratio != "":
            lang_to_offline_shard_ratio = {}
            assert os.path.exists(
                self.args.lang_to_offline_shard_ratio
            ), "provided offline shard ratio file doesn't exist: {0}".format(
                self.args.lang_to_offline_shard_ratio
            )
            with open(self.args.lang_to_offline_shard_ratio) as fin:
                for line in fin:
                    lang, ratio = line.strip().split("\t")
                    ratio = float(ratio)
                    lang_to_offline_shard_ratio[lang] = ratio

            logger.info(
                "Found offline sharded ratio: %s",
                lang_to_offline_shard_ratio,
            )
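
The shard-ratio file parsed above is a plain two-column TSV (language, ratio). A hypothetical example, parsed the same way but from an in-memory buffer so it runs standalone:

import io

# hypothetical file contents: one "<lang>\t<ratio>" pair per line
tsv = io.StringIO("en\t1.0\nde\t0.5\nro\t4.0\n")

lang_to_offline_shard_ratio = {}
for line in tsv:
    lang, ratio = line.strip().split("\t")
    lang_to_offline_shard_ratio[lang] = float(ratio)

print(lang_to_offline_shard_ratio)  # {'en': 1.0, 'de': 0.5, 'ro': 4.0}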

        if split == self.args.train_subset:
            logger.info(
                "Training on {0} languages: {1}".format(len(languages), languages)
            )
        else:
            logger.info(
                "Evaluating on {0} languages: {1}".format(len(languages), languages)
            )

        tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)

        fixed_pad_length = None
        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (
                self.args.batch_size_valid if "valid" in split else self.args.batch_size
            )

        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)
            dataset = data_utils.load_indexed_dataset(
                split_path, self.dictionary, self.args.dataset_impl, combine=combine
            )
            # print('len(dataset) =', len(dataset))
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                tokens_per_sample,
                self.args.seed,
            )

            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )

            add_eos_for_other_targets = (
                self.args.sample_break_mode is not None
                and self.args.sample_break_mode != "none"
            )
            src_lang_idx, tgt_lang_idx = None, None
            if self.args.add_bos_token:
                src_lang_idx = self.dictionary.index(lang_token(language))
                tgt_lang_idx = self.output_dictionary.index(lang_token(language))

            lang_datasets.append(
                MonolingualDataset(
                    dataset=dataset,
                    sizes=dataset.sizes,
                    src_vocab=self.dictionary,
                    tgt_vocab=self.output_dictionary,
                    add_eos_for_other_targets=add_eos_for_other_targets,
                    shuffle=True,
                    targets=self.targets,
                    fixed_pad_length=fixed_pad_length,
                    pad_to_bsz=pad_to_bsz,
                    add_bos_token=self.args.add_bos_token,
                    src_lang_idx=src_lang_idx,
                    tgt_lang_idx=tgt_lang_idx,
                )
            )

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                dataset_lengths.sum(),
            )
        )
        if split == self.args.train_subset:
            dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
            if lang_to_offline_shard_ratio is not None:
                dataset_lengths_ratio_multiplier = []
                for lang in languages:
                    assert (
                        lang in lang_to_offline_shard_ratio
                    ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                        lang,
                        self.args.lang_to_offline_shard_ratio,
                    )
                    dataset_lengths_ratio_multiplier.append(
                        lang_to_offline_shard_ratio[lang]
                    )
                dataset_lengths_ratio_multiplier = np.array(
                    dataset_lengths_ratio_multiplier
                )
                true_dataset_lengths = (
                    dataset_lengths * dataset_lengths_ratio_multiplier
                )
            else:
                true_dataset_lengths = dataset_lengths
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(true_dataset_lengths)

            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                },
            )
            size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
            # TODO: add an option for shrinking all size ratios to below 1
            # if self.args.multilang_sampling_alpha != 1:
            #   size_ratio /= size_ratio.max()

            # Fix numeric errors in size ratio computation
            #   0.999999999999999999 -> 1
            #   1.000000000000000002 -> 1
            for i in range(len(size_ratio)):
                size_ratio[i] = round(size_ratio[i], 8)

            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                },
            )
            logger.info(
                "Actual dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] > 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            logger.info(
                "Resampled dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                    for id, lang in enumerate(languages)
                },
            )
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )

    def build_dataset_for_inference(
        self, src_tokens, src_lengths, language="en_XX", **kwargs
    ):
        """
        Generate batches for inference. We prepend an eos token to src_tokens
        (or bos if `--add-bos-token` is set) and we append a <pad> to target.
        This is convenient both for generation with a prefix and LM scoring.
        """
        dataset = StripTokenDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                block_size=None,  # ignored for "eos" break mode
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            ),
            # remove eos from (end of) target sequence
            self.source_dictionary.eos(),
        )

        src_lang_idx = self.dictionary.index(lang_token(language))
        src_dataset = PrependTokenDataset(
            dataset,
            token=(
                (src_lang_idx or self.source_dictionary.bos())
                if getattr(self.args, "add_bos_token", False)
                else self.source_dictionary.eos()
            ),
        )

        max_seq_len = max(src_lengths) + 1
        tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad())
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                        pad_length=max_seq_len,
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                    pad_length=max_seq_len,
                ),
            },
            sizes=[np.array(src_lengths)],
        )
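
Roughly, each collated batch produced from this dataset has the nested shape sketched below. This is a reading aid only, not the exact collater output:

batch = {
    "id": "LongTensor of example ids",
    "net_input": {
        "src_tokens": "LongTensor [batch, max_seq_len], right-padded",
        "src_lengths": "LongTensor [batch] of true lengths",
    },
    "target": "LongTensor [batch, max_seq_len], ends with an appended <pad>",
}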

    @torch.no_grad()
    def inference_step(
        self,
        generator,
        models,
        sample,
        language="en_XX",
        prefix_tokens=None,
        constraints=None,
    ):
        # Generation will always be conditioned on bos_token
        if getattr(self.args, "add_bos_token", False):
            src_lang_idx = self.dictionary.index(lang_token(language))
            bos_token = src_lang_idx or self.source_dictionary.bos()
        else:
            bos_token = self.source_dictionary.eos()

        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the language_modeling task is not supported"
            )

        # SequenceGenerator doesn't use src_tokens directly, we need to
        # pass the `prefix_tokens` argument instead
        if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
            prefix_tokens = sample["net_input"]["src_tokens"]
            if prefix_tokens[:, 0].eq(bos_token).all():
                prefix_tokens = prefix_tokens[:, 1:]

        return generator.generate(
            models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
        )

    def eval_lm_dataloader(
        self,
        dataset,
        max_tokens: Optional[int] = 36000,
        batch_size: Optional[int] = None,
        max_positions: Optional[int] = None,
        num_shards: int = 1,
        shard_id: int = 0,
        num_workers: int = 1,
        data_buffer_size: int = 10,
        # ensures that every evaluated token has access to a context of at least
        # this size, if possible
        context_window: int = 0,
    ):
        if context_window > 0:
            dataset = LMContextWindowDataset(
                dataset=dataset,
                tokens_per_sample=self.args.tokens_per_sample,
                context_window=context_window,
                pad_idx=self.source_dictionary.pad(),
            )
        return self.get_batch_iterator(
            dataset=dataset,
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=True,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            data_buffer_size=data_buffer_size,
        )

    @property
    def source_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.output_dictionary

338
modules/voice_conversion/fairseq/tasks/multilingual_masked_lm.py
Normal file
@@ -0,0 +1,338 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import numpy as np
import torch

from fairseq import utils
from fairseq.data import (
    ConcatDataset,
    Dictionary,
    IdDataset,
    MaskTokensDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    PrependTokenDataset,
    RawLabelDataset,
    ResamplingDataset,
    SortDataset,
    TokenBlockDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("multilingual_masked_lm")
class MultiLingualMaskedLMTask(LegacyFairseqTask):
    """Task for training masked language models (e.g., BERT, RoBERTa)."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--sample-break-mode",
            default="complete",
            choices=["none", "complete", "complete_doc", "eos"],
            help='If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.',
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments "
            "per sample for BERT dataset",
        )
        parser.add_argument(
            "--mask-prob",
            default=0.15,
            type=float,
            help="probability of replacing a token with mask",
        )
        parser.add_argument(
            "--leave-unmasked-prob",
            default=0.1,
            type=float,
            help="probability that a masked token is unmasked",
        )
        parser.add_argument(
            "--random-token-prob",
            default=0.1,
            type=float,
            help="probability of replacing a token with a random token",
        )
        parser.add_argument(
            "--freq-weighted-replacement",
            action="store_true",
            help="sample random replacement words based on word frequencies",
        )
        parser.add_argument(
            "--mask-whole-words",
            default=False,
            action="store_true",
            help="mask whole words; you may also want to set --bpe",
        )
        parser.add_argument(
            "--multilang-sampling-alpha",
            type=float,
            default=1.0,
            help="smoothing alpha for sample ratios across multiple datasets",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # add mask token
        self.mask_idx = dictionary.add_symbol("<mask>")

    @classmethod
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def _get_whole_word_mask(self):
        # create masked input and targets
        if self.args.mask_whole_words:
            bpe = encoders.build_bpe(self.args)
            if bpe is not None:

                def is_beginning_of_word(i):
                    if i < self.source_dictionary.nspecial:
                        # special elements are always considered beginnings
                        return True
                    tok = self.source_dictionary[i]
                    if tok.startswith("madeupword"):
                        return True
                    try:
                        return bpe.is_beginning_of_word(tok)
                    except ValueError:
                        return True

                mask_whole_words = torch.ByteTensor(
                    list(map(is_beginning_of_word, range(len(self.source_dictionary))))
                )
        else:
            mask_whole_words = None
        return mask_whole_words
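
A self-contained sketch of the beginning-of-word mask idea, using a toy GPT-2-style vocabulary where word-initial tokens carry a leading space marker (the real code asks the configured BPE instead):

import torch

# toy vocabulary: "Ġ" marks tokens that start a new word (GPT-2 BPE style)
vocab = ["<s>", "<pad>", "</s>", "<unk>", "Ġhello", "Ġwor", "ld", "Ġ!"]
nspecial = 4  # the first four symbols are special and count as beginnings

def is_beginning_of_word(i):
    if i < nspecial:
        return True
    return vocab[i].startswith("Ġ")

mask_whole_words = torch.ByteTensor(
    [is_beginning_of_word(i) for i in range(len(vocab))]
)
print(mask_whole_words.tolist())  # [1, 1, 1, 1, 1, 1, 0, 1] -- "ld" continues "Ġwor"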

    def _get_sample_prob(self, dataset_lens):
        """
        Get smoothed sampling probability by languages. This helps low resource
        languages by upsampling them.
        """
        prob = dataset_lens / dataset_lens.sum()
        smoothed_prob = prob**self.args.multilang_sampling_alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        return smoothed_prob

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(
            name
            for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name))
        )

        logger.info("Training on {0} languages: {1}".format(len(languages), languages))
        logger.info(
            "Language to id mapping: %s",
            {lang: id for id, lang in enumerate(languages)},
        )

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    "net_input": {
                        "src_tokens": PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        "src_lengths": NumelDataset(src_dataset, reduce=False),
                    },
                    "target": PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    "nsentences": NumSamplesDataset(),
                    "ntokens": NumelDataset(src_dataset, reduce=True),
                    "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                dataset_lengths.sum(),
            )
        )
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                },
            )
            size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                },
            )

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + "_" + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ",".join(lang_splits)
                )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
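
With the defaults above (mask_prob 0.15, leave_unmasked_prob 0.1, random_token_prob 0.1), the per-token outcome probabilities follow the familiar BERT 80/10/10 split within the selected 15%. A back-of-the-envelope check (my arithmetic, not code from this diff):

mask_prob = 0.15           # fraction of tokens selected for prediction
leave_unmasked_prob = 0.1  # of those, left as the original token
random_token_prob = 0.1    # of those, replaced by a random token

p_mask_token = mask_prob * (1 - leave_unmasked_prob - random_token_prob)
p_random = mask_prob * random_token_prob
p_kept_but_predicted = mask_prob * leave_unmasked_prob

print(f"<mask>: {p_mask_token:.3f}")                   # 0.120
print(f"random token: {p_random:.3f}")                 # 0.015
print(f"kept, still predicted: {p_kept_but_predicted:.3f}")  # 0.015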

    def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
        src_dataset = PadDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            ),
            pad_idx=self.source_dictionary.pad(),
            left_pad=False,
        )
        src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
        src_dataset = NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
            },
            sizes=src_lengths,
        )
        if sort:
            src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
        return src_dataset

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

462
modules/voice_conversion/fairseq/tasks/multilingual_translation.py
Normal file
@@ -0,0 +1,462 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import os
from collections import OrderedDict
from argparse import ArgumentError

import torch
from fairseq import metrics, options, utils
from fairseq.data import (
    Dictionary,
    LanguagePairDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset

from . import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


def _lang_token(lang: str):
    return "__{}__".format(lang)


def _lang_token_index(dic: Dictionary, lang: str):
    """Return language token index."""
    idx = dic.index(_lang_token(lang))
    assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
    return idx
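
A quick sketch of how the "__<lang>__" sentinel round-trips through a fairseq Dictionary, using the two helpers defined above (toy in-memory dictionary, for illustration):

from fairseq.data import Dictionary

dic = Dictionary()
dic.add_symbol(_lang_token("de"))   # registers "__de__"

assert _lang_token("de") == "__de__"
idx = _lang_token_index(dic, "de")  # resolves without hitting <unk>
assert idx != dic.unk_index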

@register_task("multilingual_translation")
class MultilingualTranslationTask(LegacyFairseqTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, which indicates the inference language direction.
    `--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to
    the same value as training.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language (only needed for inference)')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language (only needed for inference)')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left (default: True)')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left (default: False)')
        try:
            parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the source sequence')
            parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the target sequence')
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
                            metavar='SRCTGT',
                            help='replace beginning-of-sentence in source sentence with source or target '
                                 'language token. (src/tgt)')
        parser.add_argument('--decoder-langtok', action='store_true',
                            help='replace beginning-of-sentence in target sentence with target language token')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args)
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However for other multitask settings or when we want to
        # optimize for certain languages we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask sub-classes to build
        # models other than the input lang_pairs
        self.model_lang_pairs = self.lang_pairs
        self.langs = list(dicts.keys())

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = cls.prepare(args, **kwargs)
        return cls(args, dicts, training)

    @classmethod
    def update_args(cls, args):
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError(
                "--lang-pairs is required. List all the language pairs in the training objective."
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

    @classmethod
    def prepare(cls, args, **kargs):
        cls.update_args(args)
        sorted_langs = sorted(
            list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
        )
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dicts[lang] = cls.load_dictionary(
                os.path.join(paths[0], "dict.{}.txt".format(lang))
            )
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
        return dicts, training

    def get_encoder_langtok(self, src_lang, tgt_lang):
        if self.args.encoder_langtok is None:
            return self.dicts[src_lang].eos()
        if self.args.encoder_langtok == "src":
            return _lang_token_index(self.dicts[src_lang], src_lang)
        else:
            return _lang_token_index(self.dicts[src_lang], tgt_lang)

    def get_decoder_langtok(self, tgt_lang):
        if not self.args.decoder_langtok:
            return self.dicts[tgt_lang].eos()
        return _lang_token_index(self.dicts[tgt_lang], tgt_lang)

    def alter_dataset_langtok(
        self,
        lang_pair_dataset,
        src_eos=None,
        src_lang=None,
        tgt_eos=None,
        tgt_lang=None,
    ):
        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
            return lang_pair_dataset

        new_src_eos = None
        if (
            self.args.encoder_langtok is not None
            and src_eos is not None
            and src_lang is not None
            and tgt_lang is not None
        ):
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
        else:
            src_eos = None

        new_tgt_bos = None
        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
        else:
            tgt_eos = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
        )
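
Conceptually, the wrapper above swaps the trailing EOS of each source sample for a language token. A pure-Python illustration of that swap (all token ids below are invented):

# toy ids: eos=2, "__de__"=7 (invented for illustration)
src_eos, new_src_eos = 2, 7

src_tokens = [15, 42, 8, 2]  # "... </s>"
if src_tokens[-1] == src_eos:
    src_tokens = src_tokens[:-1] + [new_src_eos]
print(src_tokens)  # [15, 42, 8, 7] -- now ends in the target-language tag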

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            langpair_dataset = load_langpair_dataset(
                data_path,
                split,
                src,
                self.dicts[src],
                tgt,
                self.dicts[tgt],
                combine=True,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )
            return self.alter_dataset_langtok(
                langpair_dataset,
                src_eos=self.dicts[src].eos(),
                src_lang=src,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in self.lang_pairs
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
        return RoundRobinZipDatasets(
            OrderedDict(
                [
                    (
                        lang_pair,
                        self.alter_dataset_langtok(
                            LanguagePairDataset(
                                src_tokens, src_lengths, self.source_dictionary
                            ),
                            src_eos=self.source_dictionary.eos(),
                            src_lang=self.args.source_lang,
                            tgt_eos=self.target_dictionary.eos(),
                            tgt_lang=self.args.target_lang,
                        ),
                    )
                ]
            ),
            eval_key=lang_pair,
        )

    def build_model(self, args, from_checkpoint=False):
        def check_args():
            messages = []
            if (
                len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs))
                != 0
            ):
                messages.append(
                    "--lang-pairs should include all the language pairs {}.".format(
                        args.lang_pairs
                    )
                )
            if self.args.encoder_langtok != args.encoder_langtok:
                messages.append(
                    "--encoder-langtok should be {}.".format(args.encoder_langtok)
                )
            if self.args.decoder_langtok != args.decoder_langtok:
                messages.append(
                    "--decoder-langtok should {} be set.".format(
                        "" if args.decoder_langtok else "not"
                    )
                )

            if len(messages) > 0:
                raise ValueError(" ".join(messages))

        # Update args -> the fact that the constructor here
        # changes the args object doesn't mean you get the same one here
        self.update_args(args)

        # Check if task args are consistent with model args
        check_args()

        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "MultilingualTranslationTask requires a FairseqMultiModel architecture"
            )
        return model

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        from collections import defaultdict

        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
        curr_lang_pairs = [
            lang_pair
            for lang_pair in self.model_lang_pairs
            if sample[lang_pair] is not None and len(sample[lang_pair]) != 0
        ]

        for idx, lang_pair in enumerate(curr_lang_pairs):

            def maybe_no_sync():
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(model, "no_sync")
                    and idx < len(curr_lang_pairs) - 1
                ):
                    return model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            with maybe_no_sync():
                loss, sample_size, logging_output = self._per_lang_pair_train_loss(
                    lang_pair,
                    model,
                    update_num,
                    criterion,
                    sample,
                    optimizer,
                    ignore_grad,
                )
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output
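
The `maybe_no_sync` trick above suppresses the distributed gradient all-reduce on every language pair except the last, so gradients are synced once per update instead of once per pair. A minimal standalone version of the pattern (dummy model, no torch.distributed required):

import contextlib

class ToyDDPModel:
    @contextlib.contextmanager
    def no_sync(self):
        # real DDP defers gradient all-reduce inside this context
        yield

def maybe_no_sync(model, idx, n_pairs, world_size):
    if world_size > 1 and hasattr(model, "no_sync") and idx < n_pairs - 1:
        return model.no_sync()      # skip sync on all but the last pair
    return contextlib.ExitStack()   # no-op context manager

model, pairs = ToyDDPModel(), ["en-de", "en-fr", "de-fr"]
for idx, pair in enumerate(pairs):
    with maybe_no_sync(model, idx, len(pairs), world_size=2):
        pass  # backward() for this pair would run here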

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        return criterion(model.models[lang_pair], sample[lang_pair])

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            from collections import defaultdict

            agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
            for lang_pair in self.eval_lang_pairs:
                if (
                    lang_pair not in sample
                    or sample[lang_pair] is None
                    or len(sample[lang_pair]) == 0
                ):
                    continue
                loss, sample_size, logging_output = self._per_lang_pair_valid_loss(
                    lang_pair, model, criterion, sample
                )
                agg_loss += loss.data.item()
                # TODO make summing of the sample sizes configurable
                agg_sample_size += sample_size
                for k in logging_output:
                    agg_logging_output[k] += logging_output[k]
                    agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if self.args.decoder_langtok:
                bos_token = _lang_token_index(
                    self.target_dictionary, self.args.target_lang
                )
            else:
                bos_token = self.target_dictionary.eos()
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=bos_token,
            )

    def reduce_metrics(self, logging_outputs, criterion):
        with metrics.aggregate():
            # pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task
            super().reduce_metrics(logging_outputs, criterion)
            for k in ["sample_size", "nsentences", "ntokens"]:
                metrics.log_scalar(k, sum(l[k] for l in logging_outputs))

    @property
    def source_dictionary(self):
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.source_lang]

    @property
    def target_dictionary(self):
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.target_lang]

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        if len(self.datasets.values()) == 0:
            return {
                "%s-%s"
                % (self.args.source_lang, self.args.target_lang): (
                    self.args.max_source_positions,
                    self.args.max_target_positions,
                )
            }
        return OrderedDict(
            [
                (key, (self.args.max_source_positions, self.args.max_target_positions))
                for split in self.datasets.keys()
                for key in self.datasets[split].datasets.keys()
            ]
        )

477
modules/voice_conversion/fairseq/tasks/nlu_finetuning.py
Normal file
@@ -0,0 +1,477 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import torch
import json

from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any

from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

from . import register_task
from .. import utils
from ..logging import metrics


logger = logging.getLogger(__name__)


class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )


def label_len_fn(label):
    return len(label.split(" "))
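
Later in this file the task accumulates character- and word-level edit distances into corpus WER/CER. A minimal sketch of that metric, using the same `editdistance` package the file imports (reference/hypothesis strings invented for illustration):

import editdistance

refs = ["play the next song", "turn off the lights"]
hyps = ["play the next songs", "turn of the lights"]

num_word_errors = num_words = 0
for ref, hyp in zip(refs, hyps):
    ref_words, hyp_words = ref.split(), hyp.split()
    num_word_errors += editdistance.eval(hyp_words, ref_words)
    num_words += len(ref_words)

print(f"WER: {100.0 * num_word_errors / num_words:.1f}%")  # 2 errors / 8 words = 25.0%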
||||
|
||||
@dataclass
|
||||
class NLUFinetuningConfig(AudioPretrainingConfig):
|
||||
# Options for reporting WER metrics during validation. Only applicable to
|
||||
# Seq2Seq models during fine-tuning
|
||||
eval_wer: bool = field(
|
||||
default=False, metadata={"help": "compute WER for Seq2Seq models"}
|
||||
)
|
||||
eval_wer_parse: bool = field(
|
||||
default=False, metadata={"help": "compute WER for Seq2Seq models"}
|
||||
)
|
||||
eval_wer_config: GenerationConfig = field(
|
||||
default_factory=lambda: GenerationConfig(),
|
||||
metadata={"help": "beam search config for evaluating wer during training"},
|
||||
)
|
||||
eval_wer_tokenizer: Any = field(
|
||||
default=None,
|
||||
metadata={"help": "tokenizer config for evaluating wer during training"},
|
||||
)
|
||||
eval_wer_post_process: str = field(
|
||||
default="letter",
|
||||
metadata={
|
||||
"help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
|
||||
},
|
||||
)
|
||||
eval_bleu: bool = field(
|
||||
default=False, metadata={"help": "evaluation with BLEU scores"}
|
||||
)
|
||||
eval_bleu_detok: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "detokenize before computing BLEU (e.g., 'moses'); "
|
||||
"required if using --eval-bleu; use 'space' to disable "
|
||||
"detokenization; see fairseq.data.encoders for other options"
|
||||
},
|
||||
)
|
||||
eval_bleu_detok_args: str = field(
|
||||
default="{}", metadata={"help": "args for building the tokenizer, if needed"}
|
||||
)
|
||||
eval_tokenized_bleu: bool = field(
|
||||
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
|
||||
)
|
||||
eval_bleu_remove_bpe: Optional[str] = field(
|
||||
default=None, metadata={"help": "remove BPE before computing BLEU"}
|
||||
)
|
||||
eval_bleu_args: str = field(
|
||||
default="{}",
|
||||
metadata={
|
||||
"help": "generation args for BLUE scoring, e.g., "
|
||||
'\'{"beam": 4, "lenpen": 0.6}\''
|
||||
},
|
||||
)
|
||||
eval_bleu_print_samples: bool = field(
|
||||
default=False, metadata={"help": "print sample generations during validation"}
|
||||
)
|
||||
autoregressive: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "required for autoregressive decoders (like seq2seq models); "
|
||||
"adds 'prev_output_tokens' to input and appends eos to target"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig)
|
||||
class NLUFinetuningTask(AudioPretrainingTask):
|
||||
""" """
|
||||
|
||||
cfg: NLUFinetuningConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
        cfg: NLUFinetuningConfig,
    ):
        super().__init__(cfg)
        self.blank_symbol = "<s>"

        self.state.add_factory("target_dictionary", self.load_target_dictionary)

    def load_target_dictionary(self):
        if self.cfg.labels:
            dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
            return Dictionary.load(dict_path)
        return None

    def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data
        label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
        skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
        text_compressor = TextCompressor(level=text_compression_level)
        with open(label_path, "r") as f:
            labels = [
                text_compressor.compress(l)
                for i, l in enumerate(f)
                if i not in skipped_indices
            ]

        assert len(labels) == len(self.datasets[split]), (
            f"labels length ({len(labels)}) and dataset length "
            f"({len(self.datasets[split])}) do not match"
        )

        process_label = LabelEncoder(self.target_dictionary)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            label_len_fn=label_len_fn,
            add_to_input=task_cfg.get("autoregressive", False),
            text_compression_level=text_compression_level,
        )

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.state.target_dictionary

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_wer_parse and self.cfg.autoregressive:
            metrics = self._inference_with_wer_parse(
                self.sequence_generator, sample, model
            )
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
            logging_output["_num_em_errors"] = metrics["num_em_errors"]
            logging_output["_num_ems"] = metrics["num_ems"]
            logging_output["_num_tree_errors"] = metrics["num_tree_errors"]
            logging_output["_num_trees"] = metrics["num_trees"]
        if self.cfg.eval_wer and self.cfg.autoregressive:
            metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = metrics.sys_len
            logging_output["_bleu_ref_len"] = metrics.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(metrics.counts) == 4
            for i in range(4):
                logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
                logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
        return loss, sample_size, logging_output
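    # Note: the underscore-prefixed logging entries above are raw counters;
    # reduce_metrics below sums them across workers and derives uer/wer/bleu.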

    def build_model(self, model_cfg: FairseqDataclass):
        model = super().build_model(model_cfg)

        if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            assert self.cfg.eval_bleu_detok is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_bleu_args)
            gen_args = Namespace(**gen_args)
            self.sequence_generator = self.build_generator([model], gen_args)

        return model
    def _inference_with_wer_parse(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        def decode_to_list(toks):
            def token_string(i):
                if i == self.target_dictionary.unk():
                    return self.target_dictionary.unk_string(False)
                else:
                    return self.target_dictionary[i]

            return [token_string(i) for i in toks]

        def is_ont_token(token):
            return "[" in token or "]" in token

        def post_process(l):
            o = []
            for w in l:
                if w == self.target_dictionary.eos_word or w == "|":
                    continue
                if w == "_":
                    o.append(" ")
                else:
                    o.append(w)
                    if is_ont_token(w):
                        o.append(" ")
            return o

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        num_em_errors, num_ems = 0, 0
        num_tree_errors, num_trees = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp_tokens = gen_out[i][0]["tokens"]
            # hyp = decode(hyp_tokens)
            ref_tokens = utils.strip_pad(
                sample["target"][i], self.target_dictionary.pad()
            )
            # ref = decode(ref_tokens)
            hyp_list = decode_to_list(hyp_tokens)
            ref_list = decode_to_list(ref_tokens)

            hyp_list = post_process(hyp_list)
            ref_list = post_process(ref_list)

            hyp = "".join(hyp_list).strip()
            ref = "".join(ref_list).strip()
            num_chars += len(ref)
            num_char_errors += editdistance.eval(hyp, ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)]
            ref_tree = [word for word in ref_list if ("[" in word or "]" in word)]
            # num_word_errors += editdistance.eval(hyp_words, ref_words)
            hyp_before = decode(hyp_tokens).split()
            ref_before = decode(ref_tokens).split()

            num_word_errors += editdistance.eval(hyp_before, ref_before)
            num_words += len(ref_before)
            if hyp != ref:
                num_em_errors += 1
            if hyp_tree != ref_tree:
                num_tree_errors += 1
            num_ems += 1
            num_trees += 1

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
            "num_ems": num_ems,
            "num_em_errors": num_em_errors,
            "num_trees": num_trees,
            "num_tree_errors": num_tree_errors,
        }
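    # Illustrative walk-through of post_process above (assuming the usual
    # fairseq EOS word "</s>"): the token list ["[IN:GET_WEATHER", "_", "rain",
    # "]", "</s>"] drops the EOS, maps "_" to a space, and appends a space
    # after each ontology token (one containing "[" or "]"), so "".join(...)
    # yields "[IN:GET_WEATHER  rain] " before .strip().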

    def _inference_with_wer(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, is_ref):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                    is_ref=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
            logger.info("T-{} {}".format(sample["id"][0], refs[0]))

        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if self.cfg.eval_wer or self.cfg.eval_wer_parse:
            zero = torch.scalar_tensor(0.0)
            num_char_errors = sum(
                log.get("_num_char_errors", zero) for log in logging_outputs
            )
            num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
            num_word_errors = sum(
                log.get("_num_word_errors", zero) for log in logging_outputs
            )
            num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
            metrics.log_scalar("_num_char_errors", num_char_errors)
            metrics.log_scalar("_num_chars", num_chars)
            metrics.log_scalar("_num_word_errors", num_word_errors)
            metrics.log_scalar("_num_words", num_words)
            if num_chars > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: meters["_num_char_errors"].sum
                    * 100.0
                    / meters["_num_chars"].sum
                    if meters["_num_chars"].sum > 0
                    else float("nan"),
                )
            if num_words > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: meters["_num_word_errors"].sum
                    * 100.0
                    / meters["_num_words"].sum
                    if meters["_num_words"].sum > 0
                    else float("nan"),
                )
        if self.cfg.eval_wer_parse:
            num_em_errors = sum(
                log.get("_num_em_errors", zero) for log in logging_outputs
            )
            num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs)
            metrics.log_scalar("_num_em_errors", num_em_errors)
            metrics.log_scalar("_num_ems", num_ems)
            num_tree_errors = sum(
                log.get("_num_tree_errors", zero) for log in logging_outputs
            )
            num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs)
            metrics.log_scalar("_num_tree_errors", num_tree_errors)
            metrics.log_scalar("_num_trees", num_trees)

            if num_ems > 0:
                metrics.log_derived(
                    "em_error",
                    lambda meters: meters["_num_em_errors"].sum
                    * 100.0
                    / meters["_num_ems"].sum
                    if meters["_num_ems"].sum > 0
                    else float("nan"),
                )
            if num_trees > 0:
                metrics.log_derived(
                    "tree_error",
                    lambda meters: meters["_num_tree_errors"].sum
                    * 100.0
                    / meters["_num_trees"].sum
                    if meters["_num_trees"].sum > 0
                    else float("nan"),
                )

        if self.cfg.eval_bleu:
            len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
            count_keys = [f"_bleu_counts_{i}" for i in range(4)]
            total_keys = [f"_bleu_totals_{i}" for i in range(4)]
            for k in len_keys + count_keys + total_keys:
                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))

            import sacrebleu

            metrics.log_derived(
                "bleu",
                lambda meters: sacrebleu.compute_bleu(
                    correct=[meters[k].sum for k in count_keys],
                    total=[meters[k].sum for k in total_keys],
                    sys_len=meters["_bleu_sys_len"].sum,
                    ref_len=meters["_bleu_ref_len"].sum,
                    smooth_method="exp",
                ).score,
            )
682
modules/voice_conversion/fairseq/tasks/online_backtranslation.py
Normal file
@@ -0,0 +1,682 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import json
import logging
import math
import os
from argparse import ArgumentError, Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fairseq
from fairseq import metrics, options, utils
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    NoisingDataset,
    PrependTokenDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
    data_utils,
    encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset

logger = logging.getLogger(__name__)

class PiecewiseLinearFn:
    """Piecewise linear function. Can be configured with a string."""

    def __init__(self, pieces: Sequence[Tuple[int, float]]):
        assert pieces == sorted(
            pieces
        ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"

        self.pieces = pieces

    def __call__(self, x: int) -> float:
        for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
            x_b, y_b = self.pieces[i + 1]
            if x_a <= x <= x_b:
                return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)

        return self.pieces[-1][1]

    @staticmethod
    def from_string(configuration: str) -> "PiecewiseLinearFn":
        """
        Parse the configuration of lambda coefficient (for scheduling).
        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                                 # to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                                 # iterations, then will linearly increase to 1
                                 # until iteration 2000
        """
        if isinstance(configuration, float):
            return PiecewiseLinearFn([(0, configuration)])

        try:
            parts = configuration.split(",")
            if len(parts) == 1:
                v = float(configuration)
                return PiecewiseLinearFn([(0, v)])

            split = [s.split(":") for s in parts]
            pieces = [(int(t), float(v)) for t, v in split]
            return PiecewiseLinearFn(pieces)
        except Exception:
            raise ValueError(
                f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
            )

    @staticmethod
    def one() -> "PiecewiseLinearFn":
        return PiecewiseLinearFn([(0, 1.0)])
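
# Illustrative usage of the schedule format documented in `from_string`:
#   fn = PiecewiseLinearFn.from_string("0:1,1000:0")
#   fn(0) -> 1.0, fn(500) -> 0.5, fn(2000) -> 0.0 (clamped to the last piece)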


@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        # Generic translation args
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('--mono-langs', metavar='MONO_LANGS',
                            help='monolingual languages for training')
        parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
                            help='language pairs for validation')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        try:
            parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the source sequence')
            parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the target sequence')
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')

        # Denoising args
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')

        # Backtranslation args
        parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
                            help='back-translation weight')
        parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
                            help='denoising auto-encoder weight')

        # Evaluation args
        parser.add_argument('--generate-one-by-one', action='store_true',
                            help='generate one sentence at a time for backtranslation')

        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on

    def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
        super().__init__(args, common_dict, common_dict)
        self.common_dict = common_dict
        self.mono_langs = mono_langs
        self.valid_lang_pairs = valid_lang_pairs

        self.SHOW_SAMPLES_INTERVAL = 1000
        # Start by showing samples
        self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
        self.SHOW_SAMPLES_NUMBER = 5
        self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
        self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)

        self.args = args
        self.data = utils.split_paths(self.args.data)
        if len(self.data) == 1:
            shards = list(Path(self.data[0]).glob("shard*"))
            if len(shards) > 0:
                # keep this as strings, since it can also be a manifold path
                old_data = self.data
                self.data = [str(shard) for shard in shards]
                logging.warning(f"Expanded data directory {old_data} to {self.data}")

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        assert args.mono_langs is not None

        mono_langs = args.mono_langs.split(",")
        valid_lang_pairs = args.valid_lang_pairs.split(",")

        # load dictionary
        dict_path = os.path.join(paths[0], "dict.txt")
        common_dict = cls.load_dictionary(dict_path)

        return cls(args, common_dict, mono_langs, valid_lang_pairs)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split == "train":
            data_path = self.data[(epoch - 1) % len(self.data)]
            dataset = self.load_train_dataset(data_path)
        else:
            # valid/test should always be the same.
            dataset = self.load_translation_dataset(split, self.data[0])

        self.datasets[split] = dataset
        return dataset

    def load_train_dataset(self, data_path: str) -> FairseqDataset:
        """The training dataset is made of backtranslation dataset and denoising dataset."""
        data = []
        for lang in self.mono_langs:
            train_path = os.path.join(data_path, lang, "train")
            # TODO: could we do the BT using the denoised sample?
            # this would halve the data loading work
            data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
            data.append(
                (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
            )

        return RoundRobinZipDatasets(OrderedDict(data))
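    # Example (illustrative): with mono_langs == ["en", "ro"], the zipped
    # training set has the keys "en-BT", "en-DENOISE", "ro-BT", "ro-DENOISE";
    # train_step later splits each key on "-" to recover (lang, subtype).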

    def _langpair_dataset(
        self, src: FairseqDataset, tgt: FairseqDataset
    ) -> LanguagePairDataset:
        return LanguagePairDataset(
            src,
            src.sizes,
            self.dictionary,
            tgt=tgt,
            tgt_sizes=tgt.sizes,
            tgt_dict=self.dictionary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            # TODO: should we shuffle? we are already sorting batches by size.
            # shuffle=True,
        )

    def _prepend_lang_bos_to_target(
        self, dataset: LanguagePairDataset, lang: str
    ) -> LanguagePairDataset:
        bos = _lang_token_index(self.dictionary, lang)
        return TransformEosLangPairDataset(
            dataset,
            src_eos=self.dictionary.eos(),
            new_src_eos=self.dictionary.eos(),
            tgt_bos=self.dictionary.eos(),
            new_tgt_bos=bos,
        )

    def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """The BT dataset is generated with (tgt, tgt) pairs.
        The actual translation to a (generated_src, tgt) pair
        is done on the fly during training.
        """
        mono_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        assert mono_dataset is not None, f"No dataset found for {lang}"

        mono_dataset_src = PrependTokenDataset(
            mono_dataset, _lang_token_index(self.dictionary, lang)
        )

        mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
        logger.info(
            f"mono_lang = {lang} "
            f"lang token index = {_lang_token_index(self.dictionary, lang)} "
            f"lang token = {_lang_token(lang)}"
        )

        mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
        return mono_dataset_bt

    def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """Classic denoising dataset"""
        dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        noisy_dataset = NoisingDataset(
            dataset,
            self.dictionary,
            seed=1,
            max_word_shuffle_distance=self.args.max_word_shuffle_distance,
            word_dropout_prob=self.args.word_dropout_prob,
            word_blanking_prob=self.args.word_blanking_prob,
        )
        noisy_dataset = PrependTokenDataset(
            noisy_dataset, _lang_token_index(self.dictionary, lang)
        )

        clean_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
        denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
        return denoising_dataset

    def load_translation_dataset(
        self, split: str, data_path: str, combine: bool = False
    ):
        # only judging with one language pair for the moment,
        # since ConcatDataset doesn't work as expected
        assert len(self.valid_lang_pairs) == 1, "For now..."
        valid_lang_pair = self.valid_lang_pairs[0]
        src, tgt = valid_lang_pair.split("-")

        # use the same function as TranslationTask
        src_tgt_dt = load_langpair_dataset(
            data_path,
            split,
            src,
            self.common_dict,
            tgt,
            self.common_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            prepend_bos_src=_lang_token_index(self.dictionary, src),
        )

        src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
        src_tgt_eos_dt.args = self.args
        return src_tgt_eos_dt

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        raise NotImplementedError

    def build_model(self, args, from_checkpoint=False):
        # torch.autograd.set_detect_anomaly(True)
        model = super().build_model(args, from_checkpoint)

        add_special_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)

        self.sequence_generators = {}
        for mono_lang in self.mono_langs:
            self.sequence_generators[mono_lang] = SequenceGenerator(
                [model],
                tgt_dict=self.dictionary,
                beam_size=1,
                max_len_a=1.3,
                max_len_b=5,
                min_len=5,
                # keep 1 to be able to prepend bos
                max_len=model.max_decoder_positions() - 1,
            )

        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
            self.tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
                )
            )

            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
            self.bleu_sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )

        return model

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.common_dict

    def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
        self._show_samples_ctr += 1
        if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
            return
        self._show_samples_ctr = 0

        ln = smp["net_input"]["src_tokens"].shape[0]

        logger.info(
            f"(r:{self.args.distributed_rank}) : "
            f"{other_lang} ---> {mono_lang} "
            f"({other_lang} was generated by back-translation.) {ln} samples"
        )

        for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
            src_tokens = smp["net_input"]["src_tokens"][i]
            tgt_tokens = smp["target"][i]

            src_str = self.dictionary.string(src_tokens, "sentencepiece")
            tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
            logger.info(
                f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
                f"\t\t[{mono_lang} original ] {tgt_str}\n"
                f"\t\t[ src tokens] {src_tokens}\n"
            )

    def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
        """
        * WARNING: smp is modified in place.
        * At the start of this function, `smp` has the same input and target:
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] | smp['target']         |
          | (from data) __en__ hello world | __en__ hello world    |
          |--------------------------------------------------------|

        * We call generator.generate(smp, bos_token = token("ro")),
          and copy the result as input
        * At the end, `smp` has the translation to the other language.
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] | smp['target']         |
          | (generated) __ro__ salut lume  | __en__ hello world    |
          |--------------------------------------------------------|
        """
        bos_token = _lang_token_index(self.dictionary, other_lang)
        generated = self.sequence_generators[orig_lang].generate(
            models=[], sample=smp, bos_token=bos_token
        )

        max_length = max([gn[0]["tokens"].size(0) for gn in generated])
        net_input = smp["net_input"]
        n_src_tokens = torch.empty(
            size=(len(generated), max_length + 1), dtype=net_input["src_tokens"].dtype
        )
        n_src_lengths = torch.empty(
            len(generated), dtype=net_input["src_lengths"].dtype
        )

        for i, gn in enumerate(generated):
            tokens = gn[0]["tokens"]
            tokens_size = tokens.size(0)
            padding_needed = max_length - tokens_size
            tokens = torch.cat([tokens.new([bos_token]), tokens])
            tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
            n_src_tokens[i] = tokens
            n_src_lengths[i] = tokens_size + 1

        device = net_input["src_tokens"].device
        # This seems to be important
        del net_input["src_tokens"]
        del net_input["src_lengths"]
        net_input["src_tokens"] = n_src_tokens.to(device)
        net_input["src_lengths"] = n_src_lengths.to(device)

    def generate(self, smp, model):
        model.eval()
        orig_lang = (
            self.dictionary[smp["net_input"]["src_tokens"][0][0]]
            .replace(" ", "")
            .replace("_", "")
        )
        bos_token = smp["net_input"]["prev_output_tokens"][0][0]
        with torch.no_grad():
            generated = self.sequence_generators[orig_lang].generate(
                models=[model], sample=smp, bos_token=bos_token
            )
        return generated

    def get_other_lang(self, lang):
        # TODO: allow more complex mapping
        if lang != self.mono_langs[0]:
            return self.mono_langs[0]
        if len(self.mono_langs) == 2:
            return self.mono_langs[1]
        return self.mono_langs[np.random.randint(1, len(self.mono_langs))]
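    # Example (illustrative): with mono_langs == ["en", "ro"],
    # get_other_lang("en") returns "ro" and get_other_lang("ro") returns "en";
    # with three or more languages a non-first language is sampled at random.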

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        agg_logging_output: Dict[str, float] = defaultdict(float)

        dataset_keys = self.datasets["train"].datasets.keys()

        weights = {
            "BT": self.lambda_bt(update_num),
            "DENOISE": self.lambda_dae(update_num),
        }
        log_keys = {"BT": "bt_", "DENOISE": "dae_"}

        for dataset_key in dataset_keys:
            smp = sample[dataset_key]
            mono_lang, task_subtype = dataset_key.split("-")
            if weights[task_subtype] == 0:
                continue

            if task_subtype == "BT":
                with torch.autograd.profiler.record_function("backtranslation"):
                    model.eval()
                    # TODO: could we translate to several languages at once?
                    # this would allow sharing encoder_out and maximizing GPU usage.
                    other_lang = self.get_other_lang(mono_lang)
                    self.backtranslate_sample(smp, mono_lang, other_lang)
                    self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
                    model.train()

            # Like in FairseqTask.train_step
            with torch.autograd.profiler.record_function("forward"):
                loss, sample_size, logging_output = criterion(model, smp)
            loss *= weights[task_subtype]
            if ignore_grad:
                loss *= 0
            with torch.autograd.profiler.record_function("backward"):
                optimizer.backward(loss)

            agg_loss += loss.item()
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
                agg_logging_output[k] += logging_output[k]

        return agg_loss, agg_sample_size, agg_logging_output

    def get_bos_token_from_sample(self, sample):
        net_input = sample["net_input"]
        source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
        source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
        target_lang_token_id = _lang_token_index(
            self.dictionary, self.get_other_lang(source_lang_token)
        )

        return target_lang_token_id
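        # Note: torch.unique(...).item() above assumes the whole batch shares
        # one source-language token; .item() would raise on a mixed batch.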

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
        if bt_sample_size:
            bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
            bt_loss_sum *= 1 / bt_sample_size / math.log(2)
            metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)

            bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
            bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
            bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
            metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
            metrics.log_derived(
                "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
            )

        dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
        if dae_sample_size:
            dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
            dae_loss_sum *= 1 / dae_sample_size / math.log(2)
            metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)

            dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
            dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
            dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
            metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
            metrics.log_derived(
                "dae_ppl",
                lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
            )


@torch.no_grad()
def extend_embedding(
    emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
    old_emb_data = emb.weight.data
    (old_vocab_size, dim) = old_emb_data.shape
    assert new_vocab_size >= old_vocab_size

    if new_vocab_size > old_vocab_size:
        emb.weight.data = torch.zeros((new_vocab_size, dim))
        emb.weight.data[:old_vocab_size, :] = old_emb_data
        # initialize new embeddings
        emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
        if hasattr(emb, "num_embeddings"):
            emb.num_embeddings = new_vocab_size
        if hasattr(emb, "out_features"):
            emb.out_features = new_vocab_size

    if getattr(emb, "bias", None) is None:
        return

    # Fix the bias.
    # Bias shape can be different from the previous vocab size
    # if the weight matrix was shared and already extended but not the bias.
    (old_vocab_size,) = emb.bias.shape
    assert new_vocab_size >= old_vocab_size
    if new_vocab_size > old_vocab_size:
        old_bias = emb.bias.data
        new_bias = torch.zeros(
            (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
        )
        new_bias[:old_vocab_size] = old_bias
        emb.bias.data = new_bias
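
# Example (illustrative): extending a (100, dim) embedding to a 103-entry
# vocabulary copies rows 0..99 unchanged and initializes rows 100..102 from
# the row at copy_from_token_id (the BOS row at the call sites below).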


def add_special_tokens_to_dict_and_model(
    dictionary: "fairseq.data.Dictionary",
    model: nn.Module,
    mono_langs: Sequence[str],
) -> None:
    embs = model.encoder.embed_tokens
    vocab_size, embedding_dim = embs.weight.shape

    # The model may or may not have a '<mask>' embedding yet
    assert (
        len(dictionary) <= vocab_size <= len(dictionary) + 1
    ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
    # TODO: we should reuse the pretrained model dict which already has <mask>
    dictionary.add_symbol("<mask>")

    for lang in mono_langs:
        lang_token = _lang_token(lang)
        dictionary.add_symbol(lang_token)
    logger.info(
        f"dictionary: {vocab_size} -> {len(dictionary)} tokens "
        f"after adding {len(mono_langs)} lang tokens."
    )

    if len(dictionary) <= vocab_size:
        return

    extend_embedding(embs, len(dictionary), dictionary.bos())
    dec_embs = model.decoder.embed_tokens
    extend_embedding(dec_embs, len(dictionary), dictionary.bos())
    lm_head = model.decoder.output_projection
    extend_embedding(lm_head, len(dictionary), dictionary.bos())
    assert lm_head.weight.shape == (len(dictionary), embedding_dim)


def _lang_token(lang: str) -> str:
    return f"__{lang}__"


def _lang_token_index(dictionary, lang: str) -> int:
    return dictionary.index(_lang_token(lang))
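
# Example (illustrative): _lang_token("en") -> "__en__", and
# _lang_token_index(dictionary, "en") is the dictionary index of that symbol
# (dictionary.unk() if the symbol was never added).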


@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
    def checksum(model: nn.Module) -> float:
        return sum(p.sum().item() for p in model.parameters())

    initial_checksum = checksum(model)
    yield model
    final_checksum = checksum(model)
    logger.info(
        f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
    )
    assert initial_checksum != final_checksum, "Model hasn't changed!"
@@ -0,0 +1,485 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict, defaultdict

from fairseq import utils
from fairseq.data import (
    BacktranslationDataset,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
    NoisingDataset,
    RoundRobinZipDatasets,
    data_utils,
    indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator

from . import register_task
from .multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair


# ported from UnsupervisedMT
def parse_lambda_config(x):
    """
    Parse the configuration of lambda coefficient (for scheduling).
    x = "3"                  # lambda will be a constant equal to x
    x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                             # to 0 during the first 1000 iterations
    x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                             # iterations, then will linearly increase to 1
                             # until iteration 2000
    """
    split = x.split(",")
    if len(split) == 1:
        return float(x), None
    else:
        split = [s.split(":") for s in split]
        assert all(len(s) == 2 for s in split)
        assert all(k.isdigit() for k, _ in split)
        assert all(
            int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
        )
        return float(split[0][1]), [(int(k), float(v)) for k, v in split]
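
# Example (illustrative):
#   parse_lambda_config("3")          -> (3.0, None)  # constant weight
#   parse_lambda_config("0:1,1000:0") -> (1.0, [(0, 1.0), (1000, 0.0)])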


@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
        optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: w0:step0,w1:step1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: w0:step0,w1:step1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: w0:step0,w1:step1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            denoising_lang_pairs = [
                "%s-%s" % (tgt, tgt)
                for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
            ]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        self.backtranslate_datasets = {}
        self.backtranslators = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_model(self, args, from_checkpoint=False):
        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()

        if update_num > 0:
            self.update_step(update_num)

        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[logging_output_key] += logging_output[k]

        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample[lang_pair],
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample[sample_key],
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample[sample_key],
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        def lambda_step_func(config, n_iter):
            """
            Update a lambda value according to its schedule configuration.
            """
            ranges = [
                i
                for i in range(len(config) - 1)
                if config[i][0] <= n_iter < config[i + 1][0]
            ]
            if len(ranges) == 0:
                assert n_iter >= config[-1][0]
                return config[-1][1]
            assert len(ranges) == 1
            i = ranges[0]
            x_a, y_a = config[i]
            x_b, y_b = config[i + 1]
            return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
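        # Example (illustrative): with a schedule [(0, 1.0), (1000, 0.0)],
        # lambda_step_func(config, 500) interpolates to 0.5 and
        # lambda_step_func(config, 2000) clamps to the final value, 0.0.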
286
modules/voice_conversion/fairseq/tasks/sentence_prediction.py
Normal file
@@ -0,0 +1,286 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import contextlib
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING, II, open_dict, OmegaConf

import numpy as np
from fairseq.data import (
    ConcatSentencesDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    OffsetTokensDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    RollDataset,
    SortDataset,
    StripTokenDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import FairseqDataclass, FairseqTask, register_task
from fairseq.dataclass import ChoiceEnum


logger = logging.getLogger(__name__)
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])


@dataclass
class SentencePredictionConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    num_classes: int = field(
        default=-1,
        metadata={"help": "number of classes or regression targets"},
    )
    init_token: Optional[int] = field(
        default=None,
        metadata={"help": "add token at the beginning of each batch item"},
    )
    separator_token: Optional[int] = field(
        default=None,
        metadata={"help": "add separator token between inputs"},
    )
    no_shuffle: bool = field(
        default=False,
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed tokens_per_sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    add_prev_output_tokens: bool = field(
        default=False,
        metadata={
            "help": "add prev_output_tokens to sample, used for encoder-decoder arch"
        },
    )
    max_positions: int = field(
        default=512,
        metadata={"help": "max tokens per example"},
    )

    regression_target: bool = II("criterion.regression_target")
    classification_head_name: str = II("criterion.classification_head_name")
    seed: int = II("common.seed")
|
||||
|
||||
|
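# Note: the II(...) fields above are OmegaConf interpolations — e.g.
# II("common.seed") is shorthand for "${common.seed}" — so their values are
# resolved from other config groups when the full hydra config is composed,
# not stored on the task config itself. Standalone sketch (not part of this
# file):
#
#     from omegaconf import II, OmegaConf
#     cfg = OmegaConf.create(
#         {"common": {"seed": 7}, "task": {"seed": II("common.seed")}}
#     )
#     assert cfg.task.seed == 7  # resolved through the interpolation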

@register_task("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionTask(FairseqTask):
    """
    Sentence (or sentence pair) prediction (classification or regression) task.

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    def __init__(self, cfg, data_dictionary, label_dictionary):
        super().__init__(cfg)
        self.dictionary = data_dictionary
        self._label_dictionary = label_dictionary

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, cfg, **kwargs):
        assert cfg.num_classes > 0, "Must set task.num_classes"

        # load data dictionary
        data_dict = cls.load_dictionary(
            os.path.join(cfg.data, "input0", "dict.txt"),
        )
        logger.info("[input] dictionary: {} types".format(len(data_dict)))

        # load label dictionary
        if not cfg.regression_target:
            label_dict = cls.load_dictionary(
                os.path.join(cfg.data, "label", "dict.txt"),
            )
            logger.info("[label] dictionary: {} types".format(len(label_dict)))
        else:
            label_dict = data_dict
        return cls(cfg, data_dict, label_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        def get_path(key, split):
            return os.path.join(self.cfg.data, key, split)

        def make_dataset(key, dictionary):
            split_path = get_path(key, split)

            try:
                dataset = data_utils.load_indexed_dataset(
                    split_path,
                    dictionary,
                    combine=combine,
                )
            except Exception as e:
                if "StorageException: [404] Path not found" in str(e):
                    logger.warning(f"dataset {e} not found")
                    dataset = None
                else:
                    raise e
            return dataset

        input0 = make_dataset("input0", self.source_dictionary)
        assert input0 is not None, "could not find dataset: {}".format(
            get_path("input0", split)
        )
        input1 = make_dataset("input1", self.source_dictionary)

        if self.cfg.init_token is not None:
            input0 = PrependTokenDataset(input0, self.cfg.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.cfg.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.cfg.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)

        with data_utils.numpy_seed(self.cfg.seed):
            shuffle = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.max_positions(),
            self.cfg.seed,
        )

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": NumelDataset(src_tokens, reduce=False),
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        if self.cfg.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset["net_input"].update(
                prev_output_tokens=prev_tokens_dataset,
            )

        if not self.cfg.regression_target:
            label_dataset = make_dataset("label", self.label_dictionary)
            if label_dataset is not None:
                dataset.update(
                    target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.label_dictionary.eos(),
                        ),
                        offset=-self.label_dictionary.nspecial,
                    )
                )
        else:
            label_path = "{0}.label".format(get_path("label", split))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert (
                        len(values) == self.cfg.num_classes
                    ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                with open(label_path) as h:
                    dataset.update(
                        target=RawLabelDataset(
                            [
                                parse_regression_target(i, line.strip())
                                for i, line in enumerate(h.readlines())
                            ]
                        )
                    )

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.cfg.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, cfg, from_checkpoint=False):
        from fairseq import models

        with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
            cfg.max_positions = self.cfg.max_positions

        model = models.build_model(cfg, self, from_checkpoint)

        model.register_classification_head(
            self.cfg.classification_head_name,
            num_classes=self.cfg.num_classes,
        )

        return model

    def max_positions(self):
        return self.cfg.max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    @property
    def label_dictionary(self):
        return self._label_dictionary
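For orientation, SentencePredictionTask expects RoBERTa-style preprocessed data under cfg.data (input0/dict.txt, optional input1/, and label/). A minimal sketch of standing the task up outside hydra — the path is hypothetical, and regression_target is set by hand here because the II(...) fields only resolve inside a composed config:

from types import SimpleNamespace

from fairseq.tasks.sentence_prediction import SentencePredictionTask

cfg = SimpleNamespace(data="/data/RTE-bin", num_classes=2, regression_target=False)
task = SentencePredictionTask.setup_task(cfg)  # loads input0/ and label/ dictionaries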
56
modules/voice_conversion/fairseq/tasks/sentence_prediction_adapters.py
Normal file
@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import contextlib
from omegaconf import open_dict, OmegaConf

from fairseq.tasks import register_task
from fairseq.tasks.sentence_prediction import (
    SentencePredictionTask,
    SentencePredictionConfig,
)


logger = logging.getLogger(__name__)


@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionAdapterTask(SentencePredictionTask):
    def build_model(self, cfg):
        from fairseq import models

        with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack():
            cfg.max_positions = self.cfg.max_positions

        model = models.build_model(cfg, self)

        model.register_classification_head(
            self.cfg.classification_head_name,
            num_classes=self.cfg.num_classes,
        )

        logger.info("Freezing Embedding Parameters")
        for parameter in model.encoder.sentence_encoder.embed_positions.parameters():
            parameter.requires_grad = False
        for (
            parameter
        ) in model.encoder.sentence_encoder.layernorm_embedding.parameters():
            parameter.requires_grad = False
        for parameter in model.encoder.sentence_encoder.embed_tokens.parameters():
            parameter.requires_grad = False

        logger.info("Freezing Adapters")
        for k, v in model.encoder.sentence_encoder.layers._modules.items():
            logger.info("Freezing Adapters in Layer " + str(k))
            if hasattr(v, "adapter_layer_norm"):
                logger.info("Freezing Adapter LN")
                for parameter in v.adapter_layer_norm.parameters():
                    parameter.requires_grad = False
            for parameter in v.adapter_modules.parameters():
                parameter.requires_grad = False

        return model
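The three freezing loops above share one pattern; an equivalent generic helper (sketch):

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # exclude every parameter of the module from gradient updates
    for parameter in module.parameters():
        parameter.requires_grad = False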
219
modules/voice_conversion/fairseq/tasks/sentence_ranking.py
Normal file
@@ -0,0 +1,219 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import numpy as np
from fairseq import utils
from fairseq.data import (
    ConcatSentencesDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    TruncateDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("sentence_ranking")
class SentenceRankingTask(LegacyFairseqTask):
    """
    Ranking task on multiple sentences.

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", metavar="FILE", help="file prefix for data")
        parser.add_argument(
            "--num-classes", type=int, help="number of sentences to be ranked"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument(
            "--separator-token", type=int, help="add separator token between inputs"
        )
        parser.add_argument("--no-shuffle", action="store_true")
        parser.add_argument(
            "--shorten-method",
            default="none",
            choices=["none", "truncate", "random_crop"],
            help="if not none, shorten sequences that exceed --tokens-per-sample",
        )
        parser.add_argument(
            "--shorten-data-split-list",
            default="",
            help="comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)',
        )
        parser.add_argument(
            "--max-option-length", type=int, help="max length for each option"
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary

    @classmethod
    def load_dictionary(cls, args, filename, source=True):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load data dictionary
        data_dict = cls.load_dictionary(
            args,
            os.path.join(args.data, "input0", "dict.txt"),
            source=True,
        )
        logger.info("[input] dictionary: {} types".format(len(data_dict)))
        return SentenceRankingTask(args, data_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        input0 = make_dataset("input0", self.source_dictionary)
        input_options = [
            make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary)
            for idx in range(self.args.num_classes)
        ]

        if self.args.separator_token is not None:
            input0 = PrependTokenDataset(input0, self.args.separator_token)

        src_tokens = []
        for input_option in input_options:
            if self.args.init_token is not None:
                input_option = PrependTokenDataset(input_option, self.args.init_token)
            if self.args.max_option_length is not None:
                input_option = TruncateDataset(
                    input_option, self.args.max_option_length
                )
            src_token = ConcatSentencesDataset(input_option, input0)
            src_token = maybe_shorten_dataset(
                src_token,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                self.args.max_positions,
                self.args.seed,
            )
            src_tokens.append(src_token)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens[0]))

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for src_token_idx in range(len(src_tokens)):
            dataset.update(
                {
                    "net_input{idx}".format(idx=src_token_idx + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[src_token_idx],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": NumelDataset(
                            src_tokens[src_token_idx], reduce=False
                        ),
                    }
                }
            )

        label_path = "{}.label".format(get_path("label", split))
        if os.path.exists(label_path):
            with open(label_path) as h:
                dataset.update(
                    target=RawLabelDataset([int(x.strip()) for x in h.readlines()])
                )

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args, from_checkpoint=False):
        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)

        model.register_classification_head(
            getattr(args, "ranking_head_name", "sentence_classification_head"),
            num_classes=1,
        )

        return model

    def max_positions(self):
        return self.args.max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
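Each candidate sentence gets its own net_input{idx} block above and is scored independently with a single-logit head; the per-option scores then act as the class logits. Conceptually (a sketch with hypothetical model/sample objects — the real wiring lives in the sentence_ranking criterion):

import torch

scores = [
    model(**sample["net_input{}".format(idx + 1)])  # (batch, 1) score per option
    for idx in range(num_classes)
]
logits = torch.cat(scores, dim=1)  # (batch, num_classes); argmax picks the best option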
41
modules/voice_conversion/fairseq/tasks/simultaneous_translation.py
Normal file
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.tasks.translation import TranslationTask, TranslationConfig

try:
    import examples.simultaneous_translation  # noqa

    import_successful = True
except BaseException:
    import_successful = False


logger = logging.getLogger(__name__)


def check_import(flag):
    if not flag:
        raise ImportError(
            "'examples.simultaneous_translation' is not correctly imported. "
            "Please consider running `pip install -e $FAIRSEQ_DIR`."
        )


@register_task("simul_speech_to_text")
class SimulSpeechToTextTask(SpeechToTextTask):
    def __init__(self, args, tgt_dict):
        check_import(import_successful)
        super().__init__(args, tgt_dict)


@register_task("simul_text_to_text", dataclass=TranslationConfig)
class SimulTextToTextTask(TranslationTask):
    def __init__(self, cfg, src_dict, tgt_dict):
        check_import(import_successful)
        super().__init__(cfg, src_dict, tgt_dict)
243
modules/voice_conversion/fairseq/tasks/span_masked_lm.py
Normal file
@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from omegaconf import II, MISSING

from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    PadDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TokenBlockDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from ..data.indexed_dataset import get_available_dataset_impl

logger = logging.getLogger(__name__)

SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])


@dataclass
class SpanMaskedLMConfig(FairseqDataclass):
    shuffle: bool = field(
        default=False,
    )
    noise_density: float = field(
        default=0.15,
        metadata={"help": "What fraction of the tokens to select as noise"},
    )
    mean_noise_span_length: float = field(
        default=3,
        metadata={"help": "Mean noise span length, must be >= 1"},
    )
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, "
            "will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    seed: int = II("common.seed")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    include_target_tokens: bool = field(
        default=False,
        metadata={
            "help": "include target tokens in model input. this is used for data2vec"
        },
    )


@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig)
class SpanMaskedLMTask(FairseqTask):
    """
    Span masked language modeling task (i.e., T5).
    """

    cfg: SpanMaskedLMConfig

    def __init__(self, cfg, dictionary):
        super().__init__(cfg)
        self.dictionary = dictionary

    @classmethod
    def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs):
        """Setup the task."""
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle"):
            cfg.shuffle = False
        return cls(cfg, dictionary)

    def _load_dataset_split(self, split, epoch, combine):
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.cfg.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.cfg.tokens_per_sample - 2,  # one less for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
        return dataset

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        self.datasets[split] = SpanMaskedTokensDataset(
            dataset,
            self.dictionary,
            noise_density=self.cfg.noise_density,
            mean_noise_span_length=self.cfg.mean_noise_span_length,
            shuffle=self.cfg.shuffle,
            seed=self.cfg.seed,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
                split,
                len(self.datasets[split]),
            )
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Generate batches for inference. We assume that the input begins with a
        bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
        """
        pad = self.source_dictionary.pad()
        eos = self.source_dictionary.eos()
        src_dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
            pad=pad,
            eos=eos,
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        prev_output_tokens = PrependTokenDataset(
            StripTokenDataset(src_dataset, eos), eos
        )
        src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                    "prev_output_tokens": PadDataset(
                        prev_output_tokens, pad_idx=pad, left_pad=False
                    ),
                },
                "target": src_dataset,
            },
            sizes=[np.array(src_lengths)],
        )

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.dictionary
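The two noise knobs compose as in T5: noise_density fixes how many tokens get masked, mean_noise_span_length how they group into spans. Rough arithmetic with the defaults above (a sketch, not the exact rounding used by SpanMaskedTokensDataset):

tokens_per_sample = 1024
noise_density = 0.15
mean_noise_span_length = 3.0

num_noise_tokens = round(tokens_per_sample * noise_density)  # ~154 tokens masked
num_noise_spans = round(num_noise_tokens / mean_noise_span_length)  # ~51 spans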
597
modules/voice_conversion/fairseq/tasks/speech_to_speech.py
Normal file
@@ -0,0 +1,597 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion

logger = logging.getLogger(__name__)


class StackUnitSequenceGenerator(nn.Module):
    def __init__(self, tgt_dict, vocab_size):
        super().__init__()
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos()
        self.unk = tgt_dict.unk()
        self.offset = len(tgt_dict) - vocab_size
        self.vocab_size = vocab_size

    def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor:
        if n_frames_per_step <= 1:
            return input

        bsz, _, n = input.shape
        assert n == n_frames_per_step

        scale = [
            pow(self.vocab_size, n_frames_per_step - 1 - i)
            for i in range(n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0).to(input.device)
        mask = input >= self.offset
        res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset
        return res
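    # pack_units collapses the n_frames_per_step units emitted per decoder step
    # into one symbol via a base-vocab_size positional code. Worked example with
    # hypothetical numbers: vocab_size V = 100, n_frames_per_step = 2, so
    # scale = [V**1, V**0] = [100, 1]; the unit pair (offset + 7, offset + 3)
    # packs to offset + 7 * 100 + 3 = offset + 703.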

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        # currently only support viterbi search for stacked units
        model = models[0]
        model.eval()

        max_len = model.max_decoder_positions()
        # TODO: incorporate max_len_a and max_len_b

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len, _ = src_tokens.size()
        n_frames_per_step = model.decoder.n_frames_per_step

        # initialize
        encoder_out = model.forward_encoder(
            src_tokens, src_lengths, speaker=sample["speaker"]
        )
        incremental_state = {}
        pred_out, attn, scores = [], [], []
        finished = src_tokens.new_zeros((bsz,)).bool()

        prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos)
        for _ in range(max_len):
            cur_out, cur_extra = model.forward_decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                incremental_state=incremental_state,
            )

            lprobs = model.get_normalized_probs([cur_out], log_probs=True)
            # never select pad, unk
            lprobs[:, :, self.pad] = -math.inf
            lprobs[:, :, self.unk] = -math.inf

            cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2)
            scores.append(cur_pred_lprob)
            pred_out.append(cur_pred_out)

            prev_output_tokens = torch.cat(
                (
                    prev_output_tokens,
                    self.pack_units(
                        cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step
                    ),
                ),
                dim=1,
            )

            attn.append(cur_extra["attn"][0])

            cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1)
            finished = finished | cur_finished
            if finished.sum().item() == bsz:
                break

        pred_out = torch.cat(pred_out, dim=1).view(bsz, -1)
        attn = torch.cat(attn, dim=2)
        alignment = attn.max(dim=1)[1]
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        scores = torch.cat(scores, dim=1)
        eos_idx = (pred_out == self.eos).nonzero(as_tuple=True)
        out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len)
        for b, l in zip(eos_idx[0], eos_idx[1]):
            out_lens[b] = min(l, out_lens[b])

        hypos = [
            [
                {
                    "tokens": pred_out[b, :out_len],
                    "attn": attn[b, :, :out_len],
                    "alignment": alignment[b, :out_len],
                    "positional_scores": scores[b, :out_len],
                    "score": utils.item(scores[b, :out_len].sum().data),
                }
            ]
            for b, out_len in zip(range(bsz), out_lens)
        ]

        return hypos


@register_task("speech_to_speech")
class SpeechToSpeechTask(LegacyFairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--multitask-config-yaml",
            type=str,
            default=None,
            help="Configuration YAML filename for the multitasks (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument(
            "--target-is-code",
            action="store_true",
            help="set if target is discrete unit instead of spectrogram",
        )
        parser.add_argument(
            "--target-code-size", type=int, default=None, help="# discrete units"
        )
        parser.add_argument(
            "--n-frames-per-step",
            type=int,
            default=1,
            help="# stacked frames, use 0 for reduced discrete unit sequence",
        )
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument(
            "--eval-args",
            type=str,
            default="{}",
            help='generation args for speech-to-unit model, e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string',
        )
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument(
            "--mcd-normalize-type",
            type=str,
            default="targ",
            choices=["targ", "pred", "path"],
        )
        parser.add_argument(
            "--vocoder",
            type=str,
            default="griffin_lim",
            choices=["griffin_lim", "hifigan", "code_hifigan"],
        )
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)
        parser.add_argument(
            "--infer-target-lang",
            type=str,
            default="",
            help="target language for inference",
        )

    def __init__(self, args, tgt_dict, infer_tgt_lang_id=None):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)

                        if not self.eos_token_mt:
                            raise Warning(
                                "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                            )

        self._infer_tgt_lang_id = infer_tgt_lang_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
        tgt_dict = None
        infer_tgt_lang_id = None
        if args.target_is_code:
            if data_cfg.prepend_tgt_lang_tag_as_bos:
                # dictionary with language tags
                dict_path = Path(args.data) / data_cfg.vocab_filename
                if not dict_path.is_file():
                    raise FileNotFoundError(
                        f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}"
                    )
                tgt_dict = Dictionary.load(dict_path.as_posix())

                # target language for inference
                if args.infer_target_lang != "":
                    tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format(
                        args.infer_target_lang
                    )
                    infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag)
                    assert infer_tgt_lang_id != tgt_dict.unk()
            else:
                assert args.target_code_size is not None

                tgt_dict = Dictionary()
                for i in range(args.target_code_size):
                    tgt_dict.add_symbol(str(i))
            logger.info(f"dictionary size: " f"{len(tgt_dict):,}")

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')

        assert args.n_frames_per_step >= 1
        assert (
            not args.eval_inference
            or (args.target_is_code and args.vocoder == "code_hifigan")
            or (not args.target_is_code and args.vocoder != "code_hifigan")
        )

        return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)

    def build_criterion(self, args):
        from fairseq import criterions

        if len(self.multitask_tasks) > 0:
            if self.args.target_is_code and not args._name.startswith("speech_to_unit"):
                raise ValueError(
                    "set --criterion speech_to_unit for speech-to-unit loss with multitask"
                )
            elif not self.args.target_is_code and not args._name.startswith(
                "speech_to_spectrogram"
            ):
                raise ValueError(
                    "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
                )

        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv(
            root=self.args.data,
            data_cfg=self.data_cfg,
            splits=split,
            is_train_split=split.startswith("train"),
            epoch=epoch,
            seed=self.args.seed,
            target_is_code=self.args.target_is_code,
            tgt_dict=self.target_dictionary,
            n_frames_per_step=self.args.n_frames_per_step,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args, from_checkpoint=False):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_transformed_channels
        args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None
        args.n_frames_per_step = self.args.n_frames_per_step

        model = super().build_model(args, from_checkpoint)

        if len(self.multitask_tasks) > 0:
            from fairseq.models.speech_to_speech.s2s_transformer import (
                S2STransformerMultitaskModelBase,
            )

            assert isinstance(model, S2STransformerMultitaskModelBase)

        if self.args.eval_inference:
            self.eval_gen_args = json.loads(self.args.eval_args)
            self.generator = self.build_generator(
                [model], Namespace(**self.eval_gen_args)
            )

        return model

    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs=None,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):

        if not self.args.target_is_code or self.args.eval_inference:
            from fairseq.models.text_to_speech.vocoder import get_vocoder

            self.vocoder = get_vocoder(self.args, self.data_cfg)
            self.vocoder = (
                self.vocoder.cuda()
                if torch.cuda.is_available() and not self.args.cpu
                else self.vocoder.cpu()
            )

        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if self.args.target_is_code:
            if self.args.n_frames_per_step == 1:
                if has_dual_decoder:
                    seq_generator = self.build_generator_dual_decoder(
                        models,
                        args,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
                else:
                    seq_generator = super().build_generator(
                        models,
                        args,
                        seq_gen_cls=None,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
            else:
                assert (
                    getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
                ), "only support viterbi search for stacked units"
                seq_generator = StackUnitSequenceGenerator(
                    self.tgt_dict,
                    self.args.target_code_size,
                )
        else:
            if has_dual_decoder:
                if getattr(args, "teacher_forcing", False):
                    raise NotImplementedError
                else:
                    from fairseq.speech_generator import MultiDecoderSpeechGenerator

                    generator = MultiDecoderSpeechGenerator

                lang_token_ids_aux = {
                    i
                    for s, i in self.tgt_dict_mt.indices.items()
                    if TextTargetMultitaskData.is_lang_tag(s)
                }

                if extra_gen_cls_kwargs is None:
                    extra_gen_cls_kwargs = {}
                extra_gen_cls_kwargs[
                    "symbols_to_strip_from_output"
                ] = lang_token_ids_aux

                eos_id_mt = (
                    self.tgt_dict_mt.index(self.eos_token_mt)
                    if self.eos_token_mt
                    else None
                )
                assert eos_id_mt != self.tgt_dict_mt.unk()
                extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

                seq_generator = generator(
                    models,
                    args,
                    self.vocoder,
                    self.data_cfg,
                    self.target_dictionary_mt,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                    **extra_gen_cls_kwargs,
                )
            else:
                if getattr(args, "teacher_forcing", False):
                    from fairseq.speech_generator import (
                        TeacherForcingAutoRegressiveSpeechGenerator,
                    )

                    generator = TeacherForcingAutoRegressiveSpeechGenerator
                    logger.info("Teacher forcing mode for generation")
                else:
                    from fairseq.speech_generator import AutoRegressiveSpeechGenerator

                    generator = AutoRegressiveSpeechGenerator

                seq_generator = generator(
                    models[0],
                    self.vocoder,
                    self.data_cfg,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                )

        return seq_generator

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        for task_name in self.multitask_tasks.keys():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if self.args.eval_inference:
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        if self.args.target_is_code:
            hypos = generator.generate([model], sample)
            tgt_lens = (
                sample["target_lengths"] - 1
            ) * self.args.n_frames_per_step  # strip <eos>
            for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)):
                hypos[b][0]["targ_waveform"] = self.vocoder(
                    {"code": f[:l] - 4},  # remove <bos>, <pad>, <eos>, <unk>
                    dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                )
                if len(hypos[b][0]["tokens"]) > 0:
                    hypos[b][0]["waveform"] = self.vocoder(
                        {"code": hypos[b][0]["tokens"] - 4},
                        dur_prediction=self.eval_gen_args.get("dur_prediction", False),
                    )
                else:
                    hypos[b][0]["waveform"] = torch.flip(
                        hypos[b][0]["targ_waveform"], dims=[0]
                    )
        else:
            hypos = [
                [hypo] for hypo in generator.generate(model, sample, has_targ=True)
            ]

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "path_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo[0]["targ_waveform"] for hypo in hypos],
            [hypo[0]["waveform"] for hypo in hypos],
            self.data_cfg.output_sample_rate,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["path_frames"] += pathmap.sum().item()
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
        losses["norm_frames"] = losses[
            f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames"
        ]

        return hypos, losses

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            if self._infer_tgt_lang_id is not None:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                    bos_token=self._infer_tgt_lang_id,
                )
            else:
                return super().inference_step(
                    generator,
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
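The MCD bookkeeping above reads insertion/deletion counts off the DTW alignment: pathmap is a (targ_frames x pred_frames) 0/1 matrix, so a target frame aligned to k predicted frames contributes k - 1 insertions, and symmetrically for deletions. A tiny sketch:

import torch

# hypothetical 3x4 DTW path: target frame 1 aligns to two predicted frames
pathmap = torch.tensor([[1, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1]])
nins = (pathmap.sum(dim=1) - 1).sum().item()  # 1 insertion
ndel = (pathmap.sum(dim=0) - 1).sum().item()  # 0 deletions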
350
modules/voice_conversion/fairseq/tasks/speech_to_text.py
Normal file
@@ -0,0 +1,350 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from pathlib import Path
from typing import List

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import MultitaskConfig
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--multitask-config-yaml",
            type=str,
            default=None,
            help="Configuration YAML filename for the multitasks (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()
        if (
            self.data_cfg.prepend_tgt_lang_tag
            and self.data_cfg.prepend_bos_and_append_tgt_lang_tag
        ):
            raise ValueError(
                "Please set only one of the two options to avoid adding target token multiple times"
            )

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)

                        if not self.eos_token_mt:
                            raise Warning(
                                "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                            )

    def _get_speaker_to_id(self):
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}"
        )

        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

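    # With prepend_tgt_lang_tag, position 0 of every target is the language ID
    # token rather than text, e.g. target = [<lang:de>, w1, ..., </s>]; the
    # criterion must skip it when computing the loss, which is exactly what
    # --ignore-prefix-size 1 requests above.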
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
is_train_split = split.startswith("train")
|
||||
pre_tokenizer = self.build_tokenizer(self.args)
|
||||
bpe_tokenizer = self.build_bpe(self.args)
|
||||
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
|
||||
root=self.args.data,
|
||||
cfg=self.data_cfg,
|
||||
splits=split,
|
||||
tgt_dict=self.tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
is_train_split=is_train_split,
|
||||
epoch=epoch,
|
||||
seed=self.args.seed,
|
||||
speaker_to_id=self.speaker_to_id,
|
||||
multitask=self.multitask_tasks,
|
||||
)
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return self.tgt_dict
|
||||
|
||||
@property
|
||||
def target_dictionary_mt(self):
|
||||
return self.tgt_dict_mt
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
return None
|
||||
|
||||
def max_positions(self):
|
||||
return self.args.max_source_positions, self.args.max_target_positions
|
||||
|
||||
def build_model(self, args, from_checkpoint=False):
|
||||
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
|
||||
args.input_channels = self.data_cfg.input_channels
|
||||
args.speaker_to_id = self.speaker_to_id
|
||||
return super(SpeechToTextTask, self).build_model(args, from_checkpoint)
|
||||
|
||||
def build_generator_dual_decoder(
|
||||
self,
|
||||
models,
|
||||
args,
|
||||
extra_gen_cls_kwargs,
|
||||
):
|
||||
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        lang_token_ids_aux = {
            i
            for s, i in self.tgt_dict_mt.indices.items()
            if TextTargetMultitaskData.is_lang_tag(s)
        }

        extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)

        eos_id_mt = (
            self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
        )
        assert eos_id_mt != self.tgt_dict_mt.unk()
        extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 0),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            len_penalty_mt=getattr(args, "lenpen_mt", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }

        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids

        eos_token = (
            args.eos_token
            if "eos_token" in args and args.eos_token is not None
            else self.data_cfg.config.get("eos_token", None)
        )

        if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
            raise Warning(
                "Please provide --eos_token to replace eos in sequence generator"
            )

        eos_id = self.tgt_dict.index(eos_token) if eos_token else None
        extra_gen_cls_kwargs["eos"] = eos_id

        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if has_dual_decoder:
            return self.build_generator_dual_decoder(
                models,
                args,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )
        else:
            return super().build_generator(
                models,
                args,
                seq_gen_cls=None,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        for task_name, task_obj in self.multitask_tasks.items():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        return loss, sample_size, logging_output

    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames
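    # Note (added comment): for this task the interactive "lines" are audio
    # file paths, so the lengths returned above are frame counts taken from
    # get_features_or_waveform rather than token counts.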

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )


class DummyMultiTask(LegacyFairseqTask):
    def __init__(self, args, tgt_dict, first_pass=False):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.first_pass = first_pass

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def is_first_pass_decoder(self):
        return self.first_pass

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        if self.args.decoder_type == "ctc":
            model = models[0]  # only support single model
            encoder_out = model(**sample)
            if hasattr(model, "get_logits"):
                emissions = model.get_logits(
                    encoder_out
                )  # no need to normalize emissions
            else:
                emissions = model.get_normalized_probs(encoder_out, log_probs=True)
            return generator.decode(
                emissions.transpose(0, 1).float().cpu().contiguous()
            )
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        if self.args.decoder_type == "ctc":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.tgt_dict)
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")
224
modules/voice_conversion/fairseq/tasks/speech_ulm_task.py
Normal file
@@ -0,0 +1,224 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import sys
import torch
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from fairseq.data import Dictionary
from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING, DictConfig


logger = logging.getLogger(__name__)


class UnitDictionary(Dictionary):
    """
    A fixed-sized Dictionary that operates on integer-valued tokens
    with a trivial (identity) token <-> id mapping.
    Special symbols (bos, eos, ...) have ids above n_units.
    """

    def __init__(
        self,
        *,  # begin keyword-only arguments
        n_units,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
        clip=False,
    ):
        self.n_units = n_units
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.clip = clip

        self.symbols = []
        self.count = []
        self.indices = {}
        for i in range(n_units):
            self.add_symbol(str(i))

        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)

        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor:
        words = [int(x) for x in line.split()]
        if self.clip:
            words = [min(self.n_units - 1, word) for word in words]
        if prepend_bos:
            words = [self.bos_index] + words
        if append_eos:
            words.append(self.eos_index)
        ids = torch.IntTensor(words)
        return ids
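
    # A quick sanity sketch of the identity mapping documented above (values
    # are illustrative, not part of the original file): with n_units=4 the
    # symbol table is ["0", "1", "2", "3", "<s>", "<pad>", "</s>", "<unk>"],
    # so bos=4, pad=5, eos=6, unk=7 -- every special id sits above the units.
    #
    #   d = UnitDictionary(n_units=4)
    #   d.encode_line("1 2 3")        # -> tensor([1, 2, 3, 6], dtype=torch.int32)
    #   d_clip = UnitDictionary(n_units=4, clip=True)
    #   d_clip.encode_line("9 2", append_eos=False)  # -> tensor([3, 2])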


@dataclass
class SpeechUnitModelingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "Path to data config.json"})
    max_token_duration: int = field(
        default=20, metadata={"help": "all token durations are capped to this value"}
    )
    tokens_per_sample: int = field(
        default=1024, metadata={"help": "tokens in a sample"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max target positions"}
    )

    # duration modeling
    ignore_duration_input: bool = field(
        default=False, metadata={"help": "whether token durations should be zeroed out"}
    )
    discrete_duration: bool = field(
        default=False, metadata={"help": "treat duration as discrete variable"}
    )
    # F0 modeling
    ignore_f0_input: bool = field(
        default=False, metadata={"help": "whether F0 should be zeroed out"}
    )
    discrete_f0: bool = field(
        default=False,
        metadata={"help": "load quantized f0 (number of bins taken from the config)"},
    )
    log_f0: bool = field(
        default=False, metadata={"help": "whether f0 should be modeled in log space"}
    )
    normalize_f0_mean: bool = field(
        default=False, metadata={"help": "whether to normalize f0 by speaker mean"}
    )
    normalize_f0_std: bool = field(
        default=False, metadata={"help": "whether to normalize f0 by speaker stddev"}
    )
    interpolate_f0: bool = field(
        default=False,
        metadata={"help": "whether to interpolate f0 for non-voiced segments"},
    )

    # input/output streams
    stream_shifts: str = field(
        default="0,0",
        metadata={
            "help": (
                "comma-separated integer list denoting right-shift for "
                "duration and pitch streams"
            )
        },
    )


@register_task("speech_unit_modeling", dataclass=SpeechUnitModelingConfig)
class SpeechUnitLanguageModelingTask(FairseqTask):
    def __init__(self, cfg: SpeechUnitModelingConfig) -> None:
        super().__init__(cfg)
        assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean

        self.data_config = ExpressiveCodeDataConfig(cfg.data)
        self._source_dictionary = self._target_dictionary = UnitDictionary(
            n_units=self.data_config.n_units
        )
        self._source_duration_dictionary = self._target_duration_dictionary = (
            UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True)
            if self.cfg.discrete_duration
            else None
        )
        self._source_f0_dictionary = self._target_f0_dictionary = (
            UnitDictionary(n_units=self.data_config.f0_vq_n_units)
            if self.cfg.discrete_f0
            else None
        )

        self._channel_names = ["token", "duration", "f0"]
        self._channel_sizes = [
            len(self.target_dictionary),
            len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1,
            len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1,
        ]

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return self._source_dictionary

    @property
    def source_duration_dictionary(self) -> Optional[Dictionary]:
        return self._source_duration_dictionary

    @property
    def source_f0_dictionary(self) -> Optional[Dictionary]:
        return self._source_f0_dictionary

    @property
    def channel_names(self) -> List[str]:
        return self._channel_names

    @property
    def channel_sizes(self) -> List[int]:
        return self._channel_sizes

    @property
    def dictionary(self) -> Optional[Dictionary]:
        return self._source_dictionary

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self._target_dictionary

    @property
    def target_duration_dictionary(self) -> Optional[Dictionary]:
        return self._target_duration_dictionary

    @property
    def target_f0_dictionary(self) -> Optional[Dictionary]:
        return self._target_f0_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return [self._dictionaries[l] for l in self.cfg.labels]
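
    # Note (added comment): `dictionaries` is carried over from upstream
    # fairseq; neither `self._dictionaries` nor `cfg.labels` is defined on
    # this class or its config, so accessing this property would raise an
    # AttributeError.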

    @classmethod
    def setup_task(
        cls, cfg: SpeechUnitModelingConfig, **kwargs
    ) -> "SpeechUnitLanguageModelingTask":
        return cls(cfg)

    def load_dataset(self, split: str, **kwargs) -> None:
        self.datasets[split] = CodeDataset(
            manifest=self.data_config.manifests[split],
            dictionary=self.source_dictionary,
            dur_dictionary=self.source_duration_dictionary,
            f0_dictionary=self.source_f0_dictionary,
            config=self.data_config,
            discrete_dur=self.cfg.discrete_duration,
            discrete_f0=self.cfg.discrete_f0,
            log_f0=self.cfg.log_f0,
            normalize_f0_mean=self.cfg.normalize_f0_mean,
            normalize_f0_std=self.cfg.normalize_f0_std,
            interpolate_f0=self.cfg.interpolate_f0,
            shifts=self.cfg.stream_shifts,
        )

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def build_criterion(self, cfg: DictConfig):
        import fairseq.criterions

        return fairseq.criterions.build_criterion(cfg, self)
501
modules/voice_conversion/fairseq/tasks/text_to_speech.py
Normal file
@@ -0,0 +1,501 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import os.path as op

import torch
import torch.nn.functional as F
import numpy as np

from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
    AutoRegressiveSpeechGenerator,
    NonAutoregressiveSpeechGenerator,
    TeacherForcingAutoRegressiveSpeechGenerator,
)

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


try:
    from tensorboardX import SummaryWriter
except ImportError:
    logger.info("Please install tensorboardX: pip install tensorboardX")
    SummaryWriter = None


@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1200,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat),
                    freeze=True,
                )
                logger.info(
                    f"load speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker

    def build_model(self, cfg, from_checkpoint=False):
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg, from_checkpoint)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model,
                vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )

    def build_default_vocoder(self):
        from fairseq.models.text_to_speech.vocoder import get_vocoder

        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[: self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()

        return hypos, losses

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]

            if is_na_model:
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"],
                    attn,
                    "alignment",
                    ret_np=True,
                    suptitle=text,
                )
            else:
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"],
                    eos_prob,
                    "eos prob",
                    ret_np=True,
                    suptitle=text,
                )

            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates, dataformats="HWC"
            )

            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
                )


def save_figure_to_numpy(fig):
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
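

# Note on the conversion above (added comment, not in the original file):
# `np.fromstring(..., sep="")` has long been deprecated for binary input;
# `np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)` is the usual
# drop-in replacement and yields the same (H, W, 3) uint8 image after reshape.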


DEFAULT_V_MIN = np.log(1e-5)


def plot_tts_output(
    data_2d,
    title_2d,
    data_1d,
    title_1d,
    figsize=(24, 4),
    v_min=DEFAULT_V_MIN,
    v_max=3,
    ret_np=False,
    suptitle="",
):
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError:
        raise ImportError("Please install Matplotlib: pip install matplotlib")

    data_2d = [
        x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
        for x in data_2d
    ]
    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        im = ax.imshow(
            x,
            origin="lower",
            aspect="auto",
            vmin=max(x.min(), v_min),
            vmax=min(x.max(), v_max),
        )
        fig.colorbar(im, cax=cax, orientation="vertical")

    if isinstance(data_1d, torch.Tensor):
        data_1d = data_1d.detach().cpu().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data


def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """
    for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs

    offset=2 (1, 1),
    offset=3 (2, 1), (1, 2)
    offset=4 (2, 2), (1, 3)
    offset=5 (2, 3)

    constraints:
        i + j = offset
        min_j <= j < max_j
        min_i <= offset - j < max_i
    """
    if max_i is None:
        max_i = offset + 1
    if max_j is None:
        max_j = offset + 1
    min_j = max(min_j, offset - max_i + 1, 0)
    max_j = min(max_j, offset - min_i + 1, offset + 1)
    j = torch.arange(min_j, max_j)
    i = offset - j
    return torch.stack([i, j])


def batch_dynamic_time_warping(distance, shapes=None):
    """full batched DTW without any constraints

    distance: (batchsize, max_M, max_N) matrix
    shapes: (batchsize,) vector specifying (M, N) for each entry
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        c = torch.stack(
            [
                cumdist[:, ind[0], ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1]],
            ],
            dim=2,
        )
        v, b = c.min(axis=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        dtwpath = [(i, j)]
        while (i != 0 or j != 0) and len(dtwpath) < 10000:
            assert i >= 0 and j >= 0
            di, dj = ptr2dij[backptr[b, i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1

    return cumdist, backptr, pathmap
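

# Tiny worked example of the DP above (hand-checked, not in the original file):
# for distance = [[[0., 2.],
#                  [3., 1.]]], the init step gives cumdist row 0 = [0, 2] and
# column 0 = [0, 3]; the single anti-diagonal update at (1, 1) picks the
# up-left predecessor (cost 0), so cumdist[-1, -1] = 0 + 1 = 1 and the
# backtrace marks the diagonal path {(0, 0), (1, 1)} in pathmap.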


def compute_l2_dist(x1, x2):
    """compute an (m, n) squared-L2 distance matrix from (m, d) and (n, d) matrices"""
    return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)


def compute_rms_dist(x1, x2):
    l2_dist = compute_l2_dist(x1, x2)
    return (l2_dist / x1.size(1)).pow(0.5)
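

# In formula form (added note): for rows x1[i], x2[j] of dimension d,
#   compute_rms_dist(x1, x2)[i, j] = sqrt( ||x1[i] - x2[j]||^2 / d ),
# i.e. the root-mean-square difference across the d feature dimensions.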


def get_divisor(pathmap, normalize_type):
    if normalize_type is None:
        return 1
    elif normalize_type == "len1":
        return pathmap.size(0)
    elif normalize_type == "len2":
        return pathmap.size(1)
    elif normalize_type == "path":
        return pathmap.sum().item()
    else:
        raise ValueError(f"normalize_type {normalize_type} not supported")


def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    d, s, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert cur_y1.ndim == 1 and cur_y2.ndim == 1
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        s.append(d[-1].size())
    max_m = max(ss[0] for ss in s)
    max_n = max(ss[1] for ss in s)
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(s).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets


def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    """

    try:
        import torchaudio
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        melkwargs = {
            "n_fft": int(0.05 * sr),
            "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr),
            "f_min": 20,
            "n_mels": 80,
            "window_fn": torch.hann_window,
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1,
        y2,
        sr,
        lambda y: mfcc_fn(y).transpose(-1, -2),
        compute_rms_dist,
        normalize_type,
    )
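

# Hedged usage sketch of the MCD helper above (shapes and values illustrative):
#   y_ref = [torch.randn(16000)]   # one reference waveform, 1s at 16 kHz
#   y_hyp = [torch.randn(16000)]   # the paired synthesized waveform
#   (mcd, (x1, x2, dist, cumdist, backptr, pathmap)), = batch_mel_cepstral_distortion(
#       y_ref, y_hyp, sr=16000
#   )
#   # mcd is the DTW-path-normalized RMS MFCC distance; pathmap marks which
#   # (reference frame, hypothesis frame) pairs lie on the alignment path.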
497
modules/voice_conversion/fairseq/tasks/translation.py
Normal file
@@ -0,0 +1,497 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import itertools
import json
import logging
import os
from typing import Optional
from argparse import Namespace
from omegaconf import II

import numpy as np
from fairseq import metrics, utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    encoders,
    indexed_dataset,
)
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task


EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index("[{}]".format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index("[{}]".format(tgt))
            )
        eos = tgt_dict.index("[{}]".format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )


@dataclass
class TranslationConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "colon-separated list of data directory paths; they are iterated "
            "over in round-robin order across epochs. Valid and test data are always "
            "taken from the first directory to avoid repeating them in all directories"
        },
    )
    source_lang: Optional[str] = field(
        default=None,
        metadata={
            "help": "source language",
            "argparse_alias": "-s",
        },
    )
    target_lang: Optional[str] = field(
        default=None,
        metadata={
            "help": "target language",
            "argparse_alias": "-t",
        },
    )
    load_alignments: bool = field(
        default=False, metadata={"help": "load the binarized alignments"}
    )
    left_pad_source: bool = field(
        default=True, metadata={"help": "pad the source on the left"}
    )
    left_pad_target: bool = field(
        default=False, metadata={"help": "pad the target on the left"}
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    upsample_primary: int = field(
        default=-1, metadata={"help": "the amount to upsample the primary dataset"}
    )
    truncate_source: bool = field(
        default=False, metadata={"help": "truncate source to max-source-positions"}
    )
    num_batch_buckets: int = field(
        default=0,
        metadata={
            "help": "if >0, then bucket source and target lengths into "
            "N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations"
        },
    )
    train_subset: str = II("dataset.train_subset")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    required_seq_len_multiple: int = II("dataset.required_seq_len_multiple")

    # options for reporting BLEU during validation
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_args: Optional[str] = field(
        default="{}",
        metadata={
            "help": 'generation args for BLEU scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
        },
    )
    eval_bleu_detok: str = field(
        default="space",
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; "
            "use 'space' to disable detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: Optional[str] = field(
        default="{}",
        metadata={"help": "args for building the tokenizer, if needed, as JSON string"},
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None,
        metadata={
            "help": "remove BPE before computing BLEU",
            "argparse_const": "@@ ",
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
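
    # Hedged CLI sketch (flag names follow from the fields above after the
    # usual underscore-to-dash conversion; data path and languages are
    # hypothetical):
    #   fairseq-train data-bin/wmt14_en_de --task translation -s en -t de \
    #       --eval-bleu --eval-bleu-detok moses \
    #       --eval-bleu-args '{"beam": 4, "lenpen": 0.6}' \
    #       --eval-bleu-print-samples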


@register_task("translation", dataclass=TranslationConfig)
class TranslationTask(FairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.
    """

    cfg: TranslationConfig

    def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict):
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: TranslationConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (TranslationConfig): configuration for this task
        """

        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        # find language pair automatically
        if cfg.source_lang is None or cfg.target_lang is None:
            cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0])
        if cfg.source_lang is None or cfg.target_lang is None:
            raise Exception(
                "Could not infer language pair, please provide it explicitly"
            )

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
        )
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
        )
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict)))
        logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))

        return cls(cfg, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        if split != self.cfg.train_subset:
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            load_alignments=self.cfg.load_alignments,
            truncate_source=self.cfg.truncate_source,
            num_buckets=self.cfg.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.cfg.required_seq_len_multiple,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )

    def build_model(self, cfg, from_checkpoint=False):
        model = super().build_model(cfg, from_checkpoint)
        if self.cfg.eval_bleu:
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )

            gen_args = json.loads(self.cfg.eval_bleu_args)
            self.sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )
        return model

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_bleu:
            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = bleu.sys_len
            logging_output["_bleu_ref_len"] = bleu.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(bleu.counts) == EVAL_BLEU_ORDER
            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.cfg.eval_bleu:

            def sum_logs(key):
                import torch

                result = sum(log.get(key, 0) for log in logging_outputs)
                if torch.is_tensor(result):
                    result = result.cpu()
                return result

            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect

                    try:
                        from sacrebleu.metrics import BLEU

                        comp_bleu = BLEU.compute_bleu
                    except ImportError:
                        # compatibility API for sacrebleu 1.x
                        import sacrebleu

                        comp_bleu = sacrebleu.compute_bleu

                    fn_sig = inspect.getfullargspec(comp_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = comp_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=int(meters["_bleu_sys_len"].sum),
                        ref_len=int(meters["_bleu_ref_len"].sum),
                        **smooth,
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict

    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"]))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("example hypothesis: " + hyps[0])
            logger.info("example reference: " + refs[0])
        if self.cfg.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
132
modules/voice_conversion/fairseq/tasks/translation_from_pretrained_bart.py
Normal file
@@ -0,0 +1,132 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset

from . import register_task
from .translation import TranslationTask, load_langpair_dataset


@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
    """
    Translate from source language to target language with a model initialized with a multilingual pretrain.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual languages, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always add all pretraining language idx '
                                 'during finetuning.')
        parser.add_argument('--prepend-bos', action='store_true',
                            help='prepend bos token to each sentence, which matches '
                                 'mBART pretraining')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.langs = args.langs.split(",")
        for d in [src_dict, tgt_dict]:
            for l in self.langs:
                d.add_symbol("[{}]".format(l))
            d.add_symbol("<mask>")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, "max_source_positions", 1024),
            max_target_positions=getattr(self.args, "max_target_positions", 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, "prepend_bos", False),
            append_source_id=True,
        )

    def build_generator(self, models, args, **unused):
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )
        else:
            from fairseq.sequence_generator import SequenceGenerator

            return SequenceGenerator(
                models,
                self.target_dictionary,
                beam_size=getattr(args, "beam", 5),
                max_len_a=getattr(args, "max_len_a", 0),
                max_len_b=getattr(args, "max_len_b", 200),
                min_len=getattr(args, "min_len", 1),
                normalize_scores=(not getattr(args, "unnormalized", False)),
                len_penalty=getattr(args, "lenpen", 1),
                unk_penalty=getattr(args, "unkpen", 0),
                temperature=getattr(args, "temperature", 1.0),
                match_source_len=getattr(args, "match_source_len", False),
                no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )
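
    # Example of the eos override above (added note): with --langs "en,de" and
    # --target-lang de, generation stops on the "[de]" language tag that
    # load_dataset appends via append_source_id=True, rather than on </s>.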

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
        source_tokens = []
        for s_t in src_tokens:
            s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            source_tokens.append(s_t)
        dataset = LanguagePairDataset(
            source_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
        return dataset
39
modules/voice_conversion/fairseq/tasks/translation_from_pretrained_xlm.py
Normal file
@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.tasks.translation import TranslationConfig, TranslationTask

from . import register_task


@dataclass
class TranslationFromPretrainedXLMConfig(TranslationConfig):
    pass


@register_task(
    "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig
)
class TranslationFromPretrainedXLMTask(TranslationTask):
    """
    Same as TranslationTask except use the MaskedLMDictionary class so that
    we can load data that was binarized with the MaskedLMDictionary class.

    This task should be used for the entire training pipeline when we want to
    train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
    training NMT with the pretrained XLM checkpoint, and subsequent evaluation
    of that trained model.
    """

    @classmethod
    def load_dictionary(cls, filename):
        """Load the masked LM dictionary from the filename

        Args:
            filename (str): the filename
        """
        return MaskedLMDictionary.load(filename)
195
modules/voice_conversion/fairseq/tasks/translation_lev.py
Normal file
@@ -0,0 +1,195 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import (
    TranslationConfig,
    TranslationTask,
    load_langpair_dataset,
)
from fairseq.utils import new_arange


NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"])


@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
    noise: NOISE_CHOICES = field(
        default="random_delete",
        metadata={"help": "type of noise"},
    )


@register_task("translation_lev", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinTask(TranslationTask):
    """
    Translation (Sequence Generation) task for Levenshtein Transformer
    See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
    """

    cfg: TranslationLevenshteinConfig

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.cfg.source_lang, self.cfg.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            prepend_bos=True,
        )

    def inject_noise(self, target_tokens):
        def _random_delete(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()

            max_len = target_tokens.size(1)
            target_mask = target_tokens.eq(pad)
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(
                target_tokens.eq(bos) | target_tokens.eq(eos), 0.0
            )
            target_score.masked_fill_(target_mask, 1)
            target_score, target_rank = target_score.sort(1)
            target_length = target_mask.size(1) - target_mask.float().sum(
                1, keepdim=True
            )

            # do not delete <bos> and <eos> (we assign 0 score for them)
            target_cutoff = (
                2
                + (
                    (target_length - 2)
                    * target_score.new_zeros(target_score.size(0), 1).uniform_()
                ).long()
            )
            target_cutoff = target_score.sort(1)[1] >= target_cutoff

            prev_target_tokens = (
                target_tokens.gather(1, target_rank)
                .masked_fill_(target_cutoff, pad)
                .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1])
            )
            prev_target_tokens = prev_target_tokens[
                :, : prev_target_tokens.ne(pad).sum(1).max()
            ]

            return prev_target_tokens

        def _random_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            target_masks = (
                target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos)
            )
            target_score = target_tokens.clone().float().uniform_()
            target_score.masked_fill_(~target_masks, 2.0)
            target_length = target_masks.sum(1).float()
            target_length = target_length * target_length.clone().uniform_()
            target_length = target_length + 1  # make sure to mask at least one token.

            _, target_rank = target_score.sort(1)
            target_cutoff = new_arange(target_rank) < target_length[:, None].long()
            prev_target_tokens = target_tokens.masked_fill(
                target_cutoff.scatter(1, target_rank, target_cutoff), unk
            )
            return prev_target_tokens

        def _full_mask(target_tokens):
            pad = self.tgt_dict.pad()
            bos = self.tgt_dict.bos()
            eos = self.tgt_dict.eos()
            unk = self.tgt_dict.unk()

            target_mask = (
                target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad)
            )
            return target_tokens.masked_fill(~target_mask, unk)

        if self.cfg.noise == "random_delete":
            return _random_delete(target_tokens)
        elif self.cfg.noise == "random_mask":
            return _random_mask(target_tokens)
        elif self.cfg.noise == "full_mask":
            return _full_mask(target_tokens)
        elif self.cfg.noise == "no_noise":
            return target_tokens
        else:
            raise NotImplementedError
|
||||
|
||||
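    # Illustrative sketch (not part of upstream fairseq): the three noise
    # modes above can be compared on a toy batch, assuming `task` is an
    # instantiated TranslationLevenshteinTask whose tgt_dict uses the default
    # fairseq special symbols (bos=0, pad=1, eos=2, unk=3):
    #
    #     import torch
    #     target = torch.tensor([[0, 7, 8, 9, 2], [0, 7, 8, 2, 1]])
    #     task.cfg.noise = "full_mask"
    #     print(task.inject_noise(target))
    #     # ordinary tokens become <unk>; bos/eos/pad survive:
    #     # tensor([[0, 3, 3, 3, 2], [0, 3, 3, 2, 1]])
    #
    # "random_mask" replaces a random subset with <unk>, while "random_delete"
    # actually shortens sequences, so the result may have fewer columns.
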
    def build_generator(self, models, args, **unused):
        # add models input to match the API for SequenceGenerator
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

        return IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0),
            max_iter=getattr(args, "iter_decode_max_iter", 10),
            beam_size=getattr(args, "iter_decode_with_beam", 1),
            reranking=getattr(args, "iter_decode_with_external_reranker", False),
            decoding_format=getattr(args, "decoding_format", None),
            adaptive=not getattr(args, "iter_decode_force_max_iter", False),
            retain_history=getattr(args, "retain_iter_history", False),
        )

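    # Usage sketch (illustrative; `task` and `models` are assumptions, not
    # defined in this file): every flag read above falls back to its default
    # via getattr, so a bare namespace is enough to build a generator:
    #
    #     from argparse import Namespace
    #     generator = task.build_generator(models, Namespace(iter_decode_max_iter=9))
    #     # refines for at most 9 iterations, greedy (iter_decode_with_beam=1)
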
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/
            raise NotImplementedError(
                "Constrained decoding with the translation_lev task is not supported"
            )

        return LanguagePairDataset(
            src_tokens, src_lengths, self.source_dictionary, append_bos=True
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        sample["prev_target"] = self.inject_noise(sample["target"])
        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            sample["prev_target"] = self.inject_noise(sample["target"])
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output
441
modules/voice_conversion/fairseq/tasks/translation_multi_simple_epoch.py
Normal file
@@ -0,0 +1,441 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import logging
import time

import torch
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    ListDataset,
    data_utils,
    iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
    MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction


###
def get_time_gap(s, e):
    return (
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    ).__str__()


###

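# Usage sketch (illustrative, not part of the file): get_time_gap() renders the
# gap between two time.time() stamps as an "H:MM:SS.ffffff" timedelta string:
#
#     start = time.time()
#     do_work()                                 # hypothetical workload
#     print(get_time_gap(start, time.time()))   # e.g. "0:00:01.500123"
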
logger = logging.getLogger(__name__)


@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of languages that are being supported
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
        training (bool): whether the task should be configured for training or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

        The translation task provides the following additional command-line
        arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
                            action=FileContentsAction)
        parser.add_argument('--keep-inference-langtok', action='store_true',
                            help='keep language tokens in inference output (e.g. for analysis or debugging)')

        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
        # fmt: on

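    # Example invocation (sketch; the data-bin path and sampling flags are
    # assumptions for illustration — the task and lang flags come from
    # add_args above and from SamplingMethod/MultilingualDatasetManager):
    #
    #     fairseq-train data-bin/multilingual \
    #         --task translation_multi_simple_epoch \
    #         --lang-pairs en-de,en-fr \
    #         --source-lang en --target-lang de \
    #         --sampling-method temperature --sampling-temperature 1.5
    #
    # --lang-pairs can also name a file; FileContentsAction reads it in.
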
    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows a multitask sub-class to build
        # models other than for the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
        self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
        self.check_dicts(self.dicts, self.source_langs, self.target_langs)

        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method
        )

    def check_dicts(self, dicts, source_langs, target_langs):
        if self.args.source_dict is not None or self.args.target_dict is not None:
            # no need to check whether the source side and target side are sharing dictionaries
            return
        src_dict = dicts[source_langs[0]]
        tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )

    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(
            cls.load_dictionary, args, **kwargs
        )
        return cls(args, langs, dicts, training)

    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always loading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from virtual data size and virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
        logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
        logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        if split in self.datasets:
            del self.datasets[split]
            logger.info("old dataset deleted manually")
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        self.datasets[split] = self.data_manager.load_dataset(
            split,
            self.training,
            epoch=epoch,
            combine=combine,
            shard_epoch=shard_epoch,
            **kwargs,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not getattr(args, "keep_inference_langtok", False):
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if tgt_langtok_spec:
                tgt_lang_tok = self.data_manager.get_decoder_langtok(
                    self.args.target_lang, tgt_langtok_spec
                )
                extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
                extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}

        return super().build_generator(
            models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_model(self, args, from_checkpoint=False):
        return super().build_model(args, from_checkpoint)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    src_tokens = sample["net_input"]["src_tokens"]
                    bsz = src_tokens.size(0)
                    prefix_tokens = (
                        torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
                    )
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
            else:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    if tgt_langtok_spec
                    else self.target_dictionary.eos(),
                )

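    # Illustrative note (not upstream code): with a decoder langtok spec and no
    # user-supplied prefix, decoding is forced to start with the target-language
    # token. For a batch of 3 sentences and a hypothetical tgt_lang_tok of 256:
    #
    #     prefix = torch.LongTensor([[256]]).expand(3, 1)   # shape (3, 1)
    #
    # so every hypothesis begins with the language token, which
    # build_generator() strips again unless --keep-inference-langtok is set.
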
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        return self.data_manager.get_source_dictionary(self.source_langs[0])

    @property
    def target_dictionary(self):
        return self.data_manager.get_target_dictionary(self.target_langs[0])

    def create_batch_sampler_func(
        self,
        max_positions,
        ignore_invalid_inputs,
        max_tokens,
        max_sentences,
        required_batch_size_multiple=1,
        seed=1,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [
                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
            ]
            split = splits[0] if len(splits) > 0 else None
            # NEW implementation
            if epoch is not None:
                # initialize the dataset with the correct starting epoch
                dataset.set_epoch(epoch)

            # get indices ordered by example size
            start_time = time.time()
            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")

            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # filter examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs
                )
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # create mini-batches with given size constraints
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler

        return construct_batch_sampler

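    # Sketch of how the factory above is consumed (illustrative values): the
    # returned closure is handed to EpochBatchIterator as its batch_sampler,
    # which calls it as construct_batch_sampler(dataset, epoch), so batches
    # are re-created each epoch:
    #
    #     sampler_fn = task.create_batch_sampler_func(
    #         max_positions=(1024, 1024), ignore_invalid_inputs=True,
    #         max_tokens=4096, max_sentences=None,
    #     )
    #     batches = sampler_fn(dataset, epoch=1)   # list of index batches
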
    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            grouped_shuffling (bool, optional): group batches with each group
                containing num_shards batches and shuffle groups. Reduces the
                difference between sequence lengths among workers for batches
                sorted by length.
            update_epoch_batch_itr (bool, optional): if true then do not use
                the cached batch iterator for the epoch

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        # initialize the dataset with the correct starting epoch
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        if self.args.sampling_method == "RoundRobin":
            batch_iter = super().get_batch_iterator(
                dataset,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs,
                required_batch_size_multiple=required_batch_size_multiple,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=num_workers,
                epoch=epoch,
                data_buffer_size=data_buffer_size,
                disable_iterator_cache=disable_iterator_cache,
                skip_remainder_batch=skip_remainder_batch,
                update_epoch_batch_itr=update_epoch_batch_itr,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions,
            ignore_invalid_inputs,
            max_tokens,
            max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
        )

        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter