Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)

This commit is contained in:
Tony Ribeiro
2023-08-10 02:58:52 +02:00
parent 28024c5649
commit 60a8e5c9c6
465 changed files with 95671 additions and 0 deletions

View File

@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import os
import sys
try:
from .version import __version__ # noqa
except ImportError:
version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_txt) as f:
__version__ = f.read().strip()
__all__ = ["pdb"]
# backwards compatibility to support `from fairseq.X import Y`
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import meters, metrics, progress_bar # noqa
sys.modules["fairseq.distributed_utils"] = distributed_utils
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar
# initialize hydra
from fairseq.dataclass.initialize import hydra_init
#hydra_init()
import fairseq.criterions # noqa
import fairseq.distributed # noqa
import fairseq.models # noqa
import fairseq.modules # noqa
import fairseq.optim # noqa
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.scoring # noqa
import fairseq.tasks # noqa
import fairseq.token_generation_constraints # noqa
import fairseq.benchmark # noqa
import fairseq.model_parallel # noqa

View File

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# import models/tasks to register them
from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa

View File

@@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from fairseq.modules.multihead_attention import MultiheadAttention
BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
def _reset_seeds():
torch.manual_seed(0)
random.seed(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def benchmark_multihead_attention(
label="",
attn_dtype=torch.uint8,
key_padding_dtype=torch.uint8,
add_bias_kv=False,
add_zero_attn=False,
static_kv=False,
batch_size=20,
embedding=EMB,
seq_len=SEQ,
num_heads=HEADS,
):
results = []
# device = torch.device("cuda")
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
key_padding_mask = _get_mask(
to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
)
q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
fns = [
original_bench_fw,
xformers_bench_fw,
original_bench_fw_bw,
xformers_bench_fw_bw,
]
for fn in fns:
results.append(
benchmark.Timer(
stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
globals={
"q": q,
"k": k,
"v": v,
"key_padding_mask": key_padding_mask,
"attn_mask": attn_mask,
"static_kv": static_kv,
"fn": fn,
},
label="multihead fw + bw",
sub_label=f"{fn.__name__}",
description=label,
).blocked_autorange(min_run_time=1)
)
compare = benchmark.Compare(results)
compare.print()
def run_benchmarks():
for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
):
label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
benchmark_multihead_attention(
label=label,
attn_dtype=attn_dtype,
key_padding_dtype=key_padding_dtype,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
run_benchmarks()

View File

@@ -0,0 +1,36 @@
import numpy as np
from fairseq.data import FairseqDataset
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size
def __getitem__(self, index):
return index
def __len__(self):
return self.num_items
def collater(self, samples):
return self.batch
@property
def sizes(self):
return np.array([self.item_size] * self.num_items)
def num_tokens(self, index):
return self.item_size
def size(self, index):
return self.item_size
def ordered_indices(self):
return np.arange(self.num_items)
@property
def supports_prefetch(self):
return False

View File

@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass
class DummyLMConfig(FairseqDataclass):
dict_size: int = 49996
dataset_size: int = 100000
tokens_per_sample: int = field(
default=512, metadata={"help": "max sequence length"}
)
add_bos_token: bool = False
batch_size: Optional[int] = II("dataset.batch_size")
max_tokens: Optional[int] = II("dataset.max_tokens")
max_target_positions: int = II("task.tokens_per_sample")
@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
def __init__(self, cfg: DummyLMConfig):
super().__init__(cfg)
# load dictionary
self.dictionary = Dictionary()
for i in range(cfg.dict_size):
self.dictionary.add_symbol("word{}".format(i))
self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8
logger.info("dictionary: {} types".format(len(self.dictionary)))
seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
self.dummy_src = seq[:-1]
self.dummy_tgt = seq[1:]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.cfg.batch_size is not None:
bsz = self.cfg.batch_size
else:
bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.cfg.tokens_per_sample, dtype=torch.long
),
},
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.cfg.tokens_per_sample,
},
num_items=self.cfg.dataset_size,
item_size=self.cfg.tokens_per_sample,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from omegaconf import II
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
logger = logging.getLogger(__name__)
@dataclass
class DummyMaskedLMConfig(FairseqDataclass):
dict_size: int = 49996
dataset_size: int = 100000
tokens_per_sample: int = field(
default=512,
metadata={
"help": "max number of total tokens over all"
" segments per sample for BERT dataset"
},
)
batch_size: Optional[int] = II("dataset.batch_size")
max_tokens: Optional[int] = II("dataset.max_tokens")
max_target_positions: int = II("task.tokens_per_sample")
@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
class DummyMaskedLMTask(FairseqTask):
def __init__(self, cfg: DummyMaskedLMConfig):
super().__init__(cfg)
self.dictionary = Dictionary()
for i in range(cfg.dict_size):
self.dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(self.dictionary)))
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8
mask_idx = 0
pad_idx = 1
seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
mask = torch.arange(2, cfg.tokens_per_sample, 7) # ~15%
src = seq.clone()
src[mask] = mask_idx
tgt = torch.full_like(seq, pad_idx)
tgt[mask] = seq[mask]
self.dummy_src = src
self.dummy_tgt = tgt
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.cfg.batch_size is not None:
bsz = self.cfg.batch_size
else:
bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.cfg.tokens_per_sample, dtype=torch.long
),
},
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.cfg.tokens_per_sample,
},
num_items=self.cfg.dataset_size,
item_size=self.cfg.tokens_per_sample,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from fairseq.models import (
FairseqDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
@staticmethod
def add_args(parser):
parser.add_argument("--num-layers", type=int, default=24)
parser.add_argument("--embed-dim", type=int, default=1024)
@classmethod
def build_model(cls, args, task):
encoder = DummyEncoder(
num_embed=len(task.target_dictionary),
embed_dim=args.embed_dim,
num_layers=args.num_layers,
)
return cls(args, encoder)
def forward(self, src_tokens, masked_tokens=None, **kwargs):
return self.decoder(src_tokens, masked_tokens=masked_tokens)
class DummyEncoder(FairseqDecoder):
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
super().__init__(Dictionary())
self.embed = nn.Embedding(
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
)
self.layers_a = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
nn.Linear(embed_dim, embed_dim), # output projection
nn.Dropout(),
)
for i in range(num_layers)
]
)
self.layers_b = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 4 * embed_dim), # FFN
nn.ReLU(),
nn.Linear(4 * embed_dim, embed_dim), # FFN
nn.Dropout(0.1),
)
for i in range(num_layers)
]
)
self.out_proj = nn.Linear(embed_dim, num_embed)
def forward(self, tokens, masked_tokens=None):
x = self.embed(tokens)
for layer_a, layer_b in zip(self.layers_a, self.layers_b):
x = x + layer_a(x)
x = x + layer_b(x)
x = self.out_proj(x)
if masked_tokens is not None:
x = x[masked_tokens]
return (x,)
def max_positions(self):
return 1024
def get_normalized_probs(self, net_output, log_probs, sample=None):
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
pass

View File

@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("--dict-size", default=49996, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument("--src-len", default=30, type=int)
parser.add_argument("--tgt-len", default=30, type=int)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
args.max_source_positions = args.src_len + dictionary.pad() + 2
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
item_size = max(self.args.src_len, self.args.tgt_len)
if self.args.batch_size is not None:
bsz = self.args.batch_size
else:
bsz = max(1, self.args.max_tokens // item_size)
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.src_len, dtype=torch.long
),
"prev_output_tokens": tgt.clone(),
},
"target": tgt,
"nsentences": bsz,
"ntokens": bsz * self.args.tgt_len,
},
num_items=self.args.dataset_size,
item_size=item_size,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size
def __getitem__(self, index):
return index
def __len__(self):
return self.num_items
def collater(self, samples):
return self.batch
@property
def sizes(self):
return np.array([self.item_size] * self.num_items)
def num_tokens(self, index):
return self.item_size
def size(self, index):
return self.item_size
def ordered_indices(self):
return np.arange(self.num_items)
@property
def supports_prefetch(self):
return False

View File

@@ -0,0 +1,381 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Pool
import torch
from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
logger = logging.getLogger("binarizer")
@dataclass
class BinarizeSummary:
"""
Keep track of what's going on in the binarizer
"""
num_seq: int = 0
replaced: tp.Optional[Counter] = None
num_tok: int = 0
@property
def num_replaced(self) -> int:
if self.replaced is None:
return 0
return sum(self.replaced.values())
@property
def replaced_percent(self) -> float:
return 100 * self.num_replaced / self.num_tok
def __str__(self) -> str:
base = f"{self.num_seq} sents, {self.num_tok} tokens"
if self.replaced is None:
return base
return f"{base}, {self.replaced_percent:.3}% replaced"
def merge(self, other: "BinarizeSummary"):
replaced = None
if self.replaced is not None:
replaced = self.replaced
if other.replaced is not None:
if replaced is None:
replaced = other.replaced
else:
replaced += other.replaced
self.replaced = replaced
self.num_seq += other.num_seq
self.num_tok += other.num_tok
class Binarizer(ABC):
"""
a binarizer describes how to take a string and build a tensor out of it
"""
@abstractmethod
def binarize_line(
self,
line: str,
summary: BinarizeSummary,
) -> torch.IntTensor:
...
def _worker_prefix(output_prefix: str, worker_id: int):
return f"{output_prefix}.pt{worker_id}"
class FileBinarizer:
"""
An file binarizer can take a file, tokenize it, and binarize each line to a tensor
"""
@classmethod
def multiprocess_dataset(
cls,
input_file: str,
dataset_impl: str,
binarizer: Binarizer,
output_prefix: str,
vocab_size=None,
num_workers=1,
) -> BinarizeSummary:
final_summary = BinarizeSummary()
offsets = find_offsets(input_file, num_workers)
# find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
# [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
# we zip the list with itself shifted by one to get all the pairs.
(first_chunk, *more_chunks) = zip(offsets, offsets[1:])
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
worker_results = [
pool.apply_async(
cls._binarize_chunk_and_finalize,
args=(
binarizer,
input_file,
start_offset,
end_offset,
_worker_prefix(
output_prefix,
worker_id,
),
dataset_impl,
),
kwds={
"vocab_size": vocab_size,
}
if vocab_size is not None
else {},
)
for worker_id, (start_offset, end_offset) in enumerate(
more_chunks, start=1
)
]
pool.close()
pool.join()
for r in worker_results:
summ = r.get()
final_summary.merge(summ)
# do not close the bin file as we need to merge the worker results in
final_ds, summ = cls._binarize_file_chunk(
binarizer,
input_file,
offset_start=first_chunk[0],
offset_end=first_chunk[1],
output_prefix=output_prefix,
dataset_impl=dataset_impl,
vocab_size=vocab_size if vocab_size is not None else None,
)
final_summary.merge(summ)
if num_workers > 1:
for worker_id in range(1, num_workers):
# merge the worker outputs
worker_output_prefix = _worker_prefix(
output_prefix,
worker_id,
)
final_ds.merge_file_(worker_output_prefix)
try:
os.remove(indexed_dataset.data_file_path(worker_output_prefix))
os.remove(indexed_dataset.index_file_path(worker_output_prefix))
except Exception as e:
logger.error(
f"couldn't remove {worker_output_prefix}.*", exc_info=e
)
# now we can close the file
idx_file = indexed_dataset.index_file_path(output_prefix)
final_ds.finalize(idx_file)
return final_summary
@staticmethod
def _binarize_file_chunk(
binarizer: Binarizer,
filename: str,
offset_start: int,
offset_end: int,
output_prefix: str,
dataset_impl: str,
vocab_size=None,
) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary)
"""
creates a dataset builder and append binarized items to it. This function does not
finalize the builder, this is useful if you want to do other things with your bin file
like appending/merging other files
"""
bin_file = indexed_dataset.data_file_path(output_prefix)
ds = indexed_dataset.make_builder(
bin_file,
impl=dataset_impl,
vocab_size=vocab_size,
)
summary = BinarizeSummary()
with Chunker(
PathManager.get_local_path(filename), offset_start, offset_end
) as line_iterator:
for line in line_iterator:
ds.add_item(binarizer.binarize_line(line, summary))
return ds, summary
@classmethod
def _binarize_chunk_and_finalize(
cls,
binarizer: Binarizer,
filename: str,
offset_start: int,
offset_end: int,
output_prefix: str,
dataset_impl: str,
vocab_size=None,
):
"""
same as above, but also finalizes the builder
"""
ds, summ = cls._binarize_file_chunk(
binarizer,
filename,
offset_start,
offset_end,
output_prefix,
dataset_impl,
vocab_size=vocab_size,
)
idx_file = indexed_dataset.index_file_path(output_prefix)
ds.finalize(idx_file)
return summ
class VocabularyDatasetBinarizer(Binarizer):
"""
Takes a Dictionary/Vocabulary, assign ids to each
token using the dictionary encode_line function.
"""
def __init__(
self,
dict: Dictionary,
tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
append_eos: bool = True,
reverse_order: bool = False,
already_numberized: bool = False,
) -> None:
self.dict = dict
self.tokenize = tokenize
self.append_eos = append_eos
self.reverse_order = reverse_order
self.already_numberized = already_numberized
super().__init__()
def binarize_line(
self,
line: str,
summary: BinarizeSummary,
):
if summary.replaced is None:
summary.replaced = Counter()
def replaced_consumer(word, idx):
if idx == self.dict.unk_index and word != self.dict.unk_word:
summary.replaced.update([word])
if self.already_numberized:
id_strings = line.strip().split()
id_list = [int(id_string) for id_string in id_strings]
if self.reverse_order:
id_list.reverse()
if self.append_eos:
id_list.append(self.dict.eos())
ids = torch.IntTensor(id_list)
else:
ids = self.dict.encode_line(
line=line,
line_tokenizer=self.tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=self.append_eos,
reverse_order=self.reverse_order,
)
summary.num_seq += 1
summary.num_tok += len(ids)
return ids
class AlignmentDatasetBinarizer(Binarizer):
"""
binarize by parsing a set of alignments and packing
them in a tensor (see utils.parse_alignment)
"""
def __init__(
self,
alignment_parser: tp.Callable[[str], torch.IntTensor],
) -> None:
super().__init__()
self.alignment_parser = alignment_parser
def binarize_line(
self,
line: str,
summary: BinarizeSummary,
):
ids = self.alignment_parser(line)
summary.num_seq += 1
summary.num_tok += len(ids)
return ids
class LegacyBinarizer:
@classmethod
def binarize(
cls,
filename: str,
dico: Dictionary,
consumer: tp.Callable[[torch.IntTensor], None],
tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
append_eos: bool = True,
reverse_order: bool = False,
offset: int = 0,
end: int = -1,
already_numberized: bool = False,
) -> tp.Dict[str, int]:
binarizer = VocabularyDatasetBinarizer(
dict=dico,
tokenize=tokenize,
append_eos=append_eos,
reverse_order=reverse_order,
already_numberized=already_numberized,
)
return cls._consume_file(
filename,
binarizer,
consumer,
offset_start=offset,
offset_end=end,
)
@classmethod
def binarize_alignments(
cls,
filename: str,
alignment_parser: tp.Callable[[str], torch.IntTensor],
consumer: tp.Callable[[torch.IntTensor], None],
offset: int = 0,
end: int = -1,
) -> tp.Dict[str, int]:
binarizer = AlignmentDatasetBinarizer(alignment_parser)
return cls._consume_file(
filename,
binarizer,
consumer,
offset_start=offset,
offset_end=end,
)
@staticmethod
def _consume_file(
filename: str,
binarizer: Binarizer,
consumer: tp.Callable[[torch.IntTensor], None],
offset_start: int,
offset_end: int,
) -> tp.Dict[str, int]:
summary = BinarizeSummary()
with Chunker(
PathManager.get_local_path(filename), offset_start, offset_end
) as line_iterator:
for line in line_iterator:
consumer(binarizer.binarize_line(line, summary))
return {
"nseq": summary.num_seq,
"nunk": summary.num_replaced,
"ntok": summary.num_tok,
"replaced": summary.replaced,
}

View File

@@ -0,0 +1,905 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict
logger = logging.getLogger(__name__)
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
from fairseq import meters
# only one worker should attempt to create the required dir
if trainer.data_parallel_rank == 0:
os.makedirs(cfg.save_dir, exist_ok=True)
prev_best = getattr(save_checkpoint, "best", val_loss)
if val_loss is not None:
best_function = max if cfg.maximize_best_checkpoint_metric else min
save_checkpoint.best = best_function(val_loss, prev_best)
if cfg.no_save:
return
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
if not trainer.should_save_checkpoint_on_current_rank:
if trainer.always_call_state_dict_during_save_checkpoint:
trainer.state_dict()
return
write_timer = meters.StopwatchMeter()
write_timer.start()
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
def is_better(a, b):
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
suffix = trainer.checkpoint_suffix
checkpoint_conds = collections.OrderedDict()
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
)
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
not end_of_epoch
and cfg.save_interval_updates > 0
and updates % cfg.save_interval_updates == 0
)
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best)
)
if val_loss is not None and cfg.keep_best_checkpoints > 0:
worst_best = getattr(save_checkpoint, "best", None)
chkpts = checkpoint_paths(
cfg.save_dir,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
cfg.best_checkpoint_metric, suffix
),
)
if len(chkpts) > 0:
p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
# add random digits to resolve ties
with data_utils.numpy_seed(epoch, updates, val_loss):
rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
checkpoint_conds[
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
)
] = worst_best is None or is_better(val_loss, worst_best)
checkpoint_conds[
"checkpoint_last{}.pt".format(suffix)
] = not cfg.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
checkpoints = [
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
if cfg.write_checkpoints_asynchronously:
# TODO[ioPath]: Need to implement a delayed asynchronous
# file copying/moving feature.
logger.warning(
f"ioPath is not copying {checkpoints[0]} to {cp} "
"since async write mode is on."
)
else:
assert PathManager.copy(
checkpoints[0], cp, overwrite=True
), f"Failed to copy {checkpoints[0]} to {cp}"
write_timer.stop()
logger.info(
"Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
checkpoints[0], epoch, updates, val_loss, write_timer.sum
)
)
if not end_of_epoch and cfg.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
if cfg.keep_interval_updates_pattern == -1:
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
)
else:
checkpoints = checkpoint_paths(
cfg.save_dir,
pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
keep_match=True,
)
checkpoints = [
x[0]
for x in checkpoints
if x[1] % cfg.keep_interval_updates_pattern != 0
]
for old_chk in checkpoints[cfg.keep_interval_updates :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
if cfg.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
)
for old_chk in checkpoints[cfg.keep_last_epochs :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
if cfg.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
cfg.save_dir,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
cfg.best_checkpoint_metric, suffix
),
)
if not cfg.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1]
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
"""
Load a checkpoint and restore the training iterator.
*passthrough_args* will be passed through to
``trainer.get_train_iterator``.
"""
reset_optimizer = cfg.reset_optimizer
reset_lr_scheduler = cfg.reset_lr_scheduler
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
reset_meters = cfg.reset_meters
reset_dataloader = cfg.reset_dataloader
if cfg.finetune_from_model is not None and (
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
):
raise ValueError(
"--finetune-from-model can not be set together with either --reset-optimizer"
" or reset_lr_scheduler or reset_meters or reset_dataloader"
)
suffix = trainer.checkpoint_suffix
if (
cfg.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not PathManager.exists(checkpoint_path)
if first_launch and getattr(cfg, "continue_once", None) is not None:
checkpoint_path = cfg.continue_once
elif cfg.finetune_from_model is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(cfg.finetune_from_model):
checkpoint_path = cfg.finetune_from_model
reset_optimizer = True
reset_lr_scheduler = True
reset_meters = True
reset_dataloader = True
logger.info(
f"loading pretrained model from {checkpoint_path}: "
"optimizer, lr scheduler, meters, dataloader will be reset"
)
else:
raise ValueError(
f"--finetune-from-model {cfg.finetune_from_model} does not exist"
)
elif suffix is not None:
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
else:
checkpoint_path = cfg.restore_file
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
raise ValueError(
"--finetune-from-model and --restore-file (non-default value) "
"can not be specified together: " + str(cfg)
)
extra_state = trainer.load_checkpoint(
checkpoint_path,
reset_optimizer,
reset_lr_scheduler,
optimizer_overrides,
reset_meters=reset_meters,
)
if (
extra_state is not None
and "best" in extra_state
and not reset_optimizer
and not reset_meters
):
save_checkpoint.best = extra_state["best"]
if extra_state is not None and not reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state["train_iterator"]
epoch_itr = trainer.get_train_iterator(
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(
epoch=1, load_dataset=True, **passthrough_args
)
trainer.lr_step(epoch_itr.epoch)
return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility).
If doing single-GPU training or if the checkpoint is only being loaded by at
most one process on each node (current default behavior is for only rank 0
to read the checkpoint from disk), load_on_all_ranks should be False to
avoid errors from torch.distributed not having been initialized or
torch.distributed.barrier() hanging.
If all processes on each node may be loading the checkpoint
simultaneously, load_on_all_ranks should be set to True to avoid I/O
conflicts.
There's currently no support for > 1 but < all processes loading the
checkpoint on each node.
"""
local_path = PathManager.get_local_path(path)
# The locally cached file returned by get_local_path() may be stale for
# remote files that are periodically updated/overwritten (ex:
# checkpoint_last.pt) - so we remove the local copy, sync across processes
# (if needed), and then download a fresh copy.
if local_path != path and PathManager.path_requires_pathmanager(path):
try:
os.remove(local_path)
except FileNotFoundError:
# With potentially multiple processes removing the same file, the
# file being missing is benign (missing_ok isn't available until
# Python 3.8).
pass
if load_on_all_ranks:
torch.distributed.barrier()
local_path = PathManager.get_local_path(path)
with open(local_path, "rb") as f:
state = torch.load(f, map_location=torch.device("cpu"))
if "args" in state and state["args"] is not None and arg_overrides is not None:
args = state["args"]
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
if "cfg" in state and state["cfg"] is not None:
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
# omegaconf version that supports object flags, or when we migrate all existing models
from omegaconf import __version__ as oc_version
from omegaconf import _utils
if oc_version < "2.2":
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
state["cfg"] = OmegaConf.create(state["cfg"])
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(state["cfg"], True)
else:
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
if arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides)
state = _upgrade_state_dict(state)
return state
def load_model_ensemble(
filenames,
arg_overrides: Optional[Dict[str, Any]] = None,
task=None,
strict=True,
suffix="",
num_shards=1,
state=None,
):
"""Loads an ensemble of models.
Args:
filenames (List[str]): checkpoint files to load
arg_overrides (Dict[str,Any], optional): override model args that
were used during model training
task (fairseq.tasks.FairseqTask, optional): task to use for loading
"""
assert not (
strict and num_shards > 1
), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble, args, _task = load_model_ensemble_and_task(
filenames,
arg_overrides,
task,
strict,
suffix,
num_shards,
state,
)
return ensemble, args
def get_maybe_sharded_checkpoint_filename(
filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
orig_filename = filename
filename = filename.replace(".pt", suffix + ".pt")
fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
if PathManager.exists(fsdp_filename):
return fsdp_filename
elif num_shards > 1:
return model_parallel_filename
else:
return filename
def load_model_ensemble_and_task(
filenames,
arg_overrides: Optional[Dict[str, Any]] = None,
task=None,
strict=True,
suffix="",
num_shards=1,
state=None,
):
assert state is None or len(filenames) == 1
from fairseq import tasks
assert not (
strict and num_shards > 1
), "Cannot load state dict with strict=True and checkpoint shards > 1"
ensemble = []
cfg = None
for filename in filenames:
orig_filename = filename
model_shard_state = {"shard_weights": [], "shard_metadata": []}
assert num_shards > 0
st = time.time()
for shard_idx in range(num_shards):
filename = get_maybe_sharded_checkpoint_filename(
orig_filename, suffix, shard_idx, num_shards
)
if not PathManager.exists(filename):
raise IOError("Model file not found: {}".format(filename))
if state is None:
state = load_checkpoint_to_cpu(filename, arg_overrides)
if "args" in state and state["args"] is not None:
cfg = convert_namespace_to_omegaconf(state["args"])
elif "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
if task is None:
task = tasks.setup_task(cfg.task)
if "task_state" in state:
task.load_state_dict(state["task_state"])
if "fsdp_metadata" in state and num_shards > 1:
model_shard_state["shard_weights"].append(state["model"])
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
# check FSDP import before the code goes too far
if not has_FSDP:
raise ImportError(
"Cannot find FullyShardedDataParallel. "
"Please install fairscale with: pip install fairscale"
)
if shard_idx == num_shards - 1:
consolidated_model_state = FSDP.consolidate_shard_weights(
shard_weights=model_shard_state["shard_weights"],
shard_metadata=model_shard_state["shard_metadata"],
)
model = task.build_model(cfg.model)
if (
"optimizer_history" in state
and len(state["optimizer_history"]) > 0
and "num_updates" in state["optimizer_history"][-1]
):
model.set_num_updates(
state["optimizer_history"][-1]["num_updates"]
)
model.load_state_dict(
consolidated_model_state, strict=strict, model_cfg=cfg.model
)
else:
# model parallel checkpoint or unsharded checkpoint
# support old external tasks
argspec = inspect.getfullargspec(task.build_model)
if "from_checkpoint" in argspec.args:
model = task.build_model(cfg.model, from_checkpoint=True)
else:
model = task.build_model(cfg.model)
if (
"optimizer_history" in state
and len(state["optimizer_history"]) > 0
and "num_updates" in state["optimizer_history"][-1]
):
model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
model.load_state_dict(
state["model"], strict=strict, model_cfg=cfg.model
)
# reset state so it gets loaded for the next model in ensemble
state = None
if shard_idx % 10 == 0 and shard_idx > 0:
elapsed = time.time() - st
logger.info(
f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
)
# build model for ensemble
ensemble.append(model)
return ensemble, cfg, task
def load_model_ensemble_and_task_from_hf_hub(
model_id,
cache_dir: Optional[str] = None,
arg_overrides: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"You need to install huggingface_hub to use `load_from_hf_hub`. "
"See https://pypi.org/project/huggingface-hub/ for installation."
)
library_name = "fairseq"
cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
cache_dir = snapshot_download(
model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
)
_arg_overrides = arg_overrides or {}
_arg_overrides["data"] = cache_dir
return load_model_ensemble_and_task(
[p.as_posix() for p in Path(cache_dir).glob("*.pt")],
arg_overrides=_arg_overrides,
)
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = PathManager.ls(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = float(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
if keep_match:
return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
else:
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
def torch_persistent_save(obj, filename, async_write: bool = False):
if async_write:
with PathManager.opena(filename, "wb") as f:
_torch_persistent_save(obj, f)
else:
if PathManager.supports_rename(filename):
# do atomic save
with PathManager.open(filename + ".tmp", "wb") as f:
_torch_persistent_save(obj, f)
PathManager.rename(filename + ".tmp", filename)
else:
# fallback to non-atomic save
with PathManager.open(filename, "wb") as f:
_torch_persistent_save(obj, f)
def _torch_persistent_save(obj, f):
if isinstance(f, str):
with PathManager.open(f, "wb") as h:
torch_persistent_save(obj, h)
return
for i in range(3):
try:
return torch.save(obj, f)
except Exception:
if i == 2:
logger.error(traceback.format_exc())
raise
def _upgrade_state_dict(state):
"""Helper for upgrading old model checkpoints."""
# add optimizer_history
if "optimizer_history" not in state:
state["optimizer_history"] = [
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
]
state["last_optimizer_state"] = state["optimizer"]
del state["optimizer"]
del state["best_loss"]
# move extra_state into sub-dictionary
if "epoch" in state and "extra_state" not in state:
state["extra_state"] = {
"epoch": state["epoch"],
"batch_offset": state["batch_offset"],
"val_loss": state["val_loss"],
}
del state["epoch"]
del state["batch_offset"]
del state["val_loss"]
# reduce optimizer history's memory usage (only keep the last state)
if "optimizer" in state["optimizer_history"][-1]:
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
for optim_hist in state["optimizer_history"]:
del optim_hist["optimizer"]
# record the optimizer class name
if "optimizer_name" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
# move best_loss into lr_scheduler_state
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["lr_scheduler_state"] = {
"best": state["optimizer_history"][-1]["best_loss"]
}
del state["optimizer_history"][-1]["best_loss"]
# keep track of number of updates
if "num_updates" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["num_updates"] = 0
# use stateful training data iterator
if "train_iterator" not in state["extra_state"]:
state["extra_state"]["train_iterator"] = {
"epoch": state["extra_state"].get("epoch", 0),
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
}
# backward compatibility, cfg updates
if "args" in state and state["args"] is not None:
# old model checkpoints may not have separate source/target positions
if hasattr(state["args"], "max_positions") and not hasattr(
state["args"], "max_source_positions"
):
state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions
# default to translation task
if not hasattr(state["args"], "task"):
state["args"].task = "translation"
# --raw-text and --lazy-load are deprecated
if getattr(state["args"], "raw_text", False):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
if state["extra_state"]["train_iterator"] is not None:
state["extra_state"]["train_iterator"]["epoch"] = max(
state["extra_state"]["train_iterator"].get("epoch", 1), 1
)
# --remove-bpe ==> --postprocess
if hasattr(state["args"], "remove_bpe"):
state["args"].post_process = state["args"].remove_bpe
# --min-lr ==> --stop-min-lr
if hasattr(state["args"], "min_lr"):
state["args"].stop_min_lr = state["args"].min_lr
del state["args"].min_lr
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
if hasattr(state["args"], "criterion") and state["args"].criterion in [
"binary_cross_entropy",
"kd_binary_cross_entropy",
]:
state["args"].criterion = "wav2vec"
# remove log_keys if it's None (criteria will supply a default value of [])
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
delattr(state["args"], "log_keys")
# speech_pretraining => audio pretraining
if (
hasattr(state["args"], "task")
and state["args"].task == "speech_pretraining"
):
state["args"].task = "audio_pretraining"
# audio_cpc => wav2vec
if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
state["args"].arch = "wav2vec"
# convert legacy float learning rate to List[float]
if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
state["args"].lr = [state["args"].lr]
# convert task data arg to a string instead of List[string]
if (
hasattr(state["args"], "data")
and isinstance(state["args"].data, list)
and len(state["args"].data) > 0
):
state["args"].data = state["args"].data[0]
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
if "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
with open_dict(cfg):
# any upgrades for Hydra-based configs
if (
"task" in cfg
and "eval_wer_config" in cfg.task
and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
):
cfg.task.eval_wer_config.print_alignment = "hard"
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
cfg.generation.print_alignment = (
"hard" if cfg.generation.print_alignment else None
)
if (
"model" in cfg
and "w2v_args" in cfg.model
and cfg.model.w2v_args is not None
and (
hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
)
and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
and cfg.model.w2v_args.task.eval_wer_config is not None
and isinstance(
cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
)
):
cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
return state
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
"""Prune the given state_dict if desired for LayerDrop
(https://arxiv.org/abs/1909.11556).
Training with LayerDrop allows models to be robust to pruning at inference
time. This function prunes state_dict to allow smaller models to be loaded
from a larger model and re-maps the existing state_dict for this to occur.
It's called by functions that load models from checkpoints and does not
need to be called directly.
"""
arch = None
if model_cfg is not None:
arch = (
model_cfg._name
if isinstance(model_cfg, DictConfig)
else getattr(model_cfg, "arch", None)
)
if not model_cfg or arch is None or arch == "ptt_transformer":
# args should not be none, but don't crash if it is.
return state_dict
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict
# apply pruning
logger.info(
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
)
def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted(
int(layer_string) for layer_string in layers_to_keep.split(",")
)
mapping_dict = {}
for i in range(len(keep_layers)):
mapping_dict[str(keep_layers[i])] = str(i)
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
pruning_passes = []
if encoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
if decoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
new_state_dict = {}
for layer_name in state_dict.keys():
match = re.search(r"\.layers\.(\d+)\.", layer_name)
# if layer has no number in it, it is a supporting layer, such as an
# embedding
if not match:
new_state_dict[layer_name] = state_dict[layer_name]
continue
# otherwise, layer should be pruned.
original_layer_number = match.group(1)
# figure out which mapping dict to replace from
for pruning_pass in pruning_passes:
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
"substitution_regex"
].search(layer_name):
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
substitution_match = pruning_pass["substitution_regex"].search(
layer_name
)
new_state_key = (
layer_name[: substitution_match.start(1)]
+ new_layer_number
+ layer_name[substitution_match.end(1) :]
)
new_state_dict[new_state_key] = state_dict[layer_name]
# Since layers are now pruned, *_layers_to_keep are no longer needed.
# This is more of "It would make it work fix" rather than a proper fix.
if isinstance(model_cfg, DictConfig):
context = open_dict(model_cfg)
else:
context = contextlib.ExitStack()
with context:
if hasattr(model_cfg, "encoder_layers_to_keep"):
model_cfg.encoder_layers_to_keep = None
if hasattr(model_cfg, "decoder_layers_to_keep"):
model_cfg.decoder_layers_to_keep = None
return new_state_dict
def load_pretrained_component_from_model(
component: Union[FairseqEncoder, FairseqDecoder],
checkpoint: str,
strict: bool = True,
):
"""
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
provided `component` object. If state_dict fails to load, there may be a
mismatch in the architecture of the corresponding `component` found in the
`checkpoint` file.
"""
if not PathManager.exists(checkpoint):
raise IOError("Model file not found: {}".format(checkpoint))
state = load_checkpoint_to_cpu(checkpoint)
if isinstance(component, FairseqEncoder):
component_type = "encoder"
elif isinstance(component, FairseqDecoder):
component_type = "decoder"
else:
raise ValueError(
"component to load must be either a FairseqEncoder or "
"FairseqDecoder. Loading other component types are not supported."
)
component_state_dict = OrderedDict()
for key in state["model"].keys():
if key.startswith(component_type):
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
component_subkey = key[len(component_type) + 1 :]
component_state_dict[component_subkey] = state["model"][key]
component.load_state_dict(component_state_dict, strict=strict)
return component
def verify_checkpoint_directory(save_dir: str) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
temp_file_path = os.path.join(save_dir, "dummy")
try:
with open(temp_file_path, "w"):
pass
except OSError as e:
logger.warning(
"Unable to access checkpoint save directory: {}".format(save_dir)
)
raise e
else:
os.remove(temp_file_path)
def save_ema_as_checkpoint(src_path, dst_path):
state = load_ema_from_checkpoint(src_path)
torch_persistent_save(state, dst_path)
def load_ema_from_checkpoint(fpath):
"""Loads exponential moving averaged (EMA) checkpoint from input and
returns a model with ema weights.
Args:
fpath: A string path of checkpoint to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
new_state = None
with PathManager.open(fpath, "rb") as f:
new_state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
# EMA model is stored in a separate "extra state"
model_params = new_state["extra_state"]["ema"]
for key in list(model_params.keys()):
p = model_params[key]
if isinstance(p, torch.HalfTensor):
p = p.float()
if key not in params_dict:
params_dict[key] = p.clone()
# NOTE: clone() is needed in case of p is a shared parameter
else:
raise ValueError("Key {} is repeated in EMA model params.".format(key))
if len(params_dict) == 0:
raise ValueError(
f"Input checkpoint path '{fpath}' does not contain "
"ema model weights, is this model trained with EMA?"
)
new_state["model"] = params_dict
return new_state

View File

@@ -0,0 +1,55 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
#include <torch/extension.h>
#include <vector>
/*
CPP Binding for CUDA OP
*/
// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(
torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size) {
CHECK_INPUT(tokens);
CHECK_INPUT(lprobs);
assert(bsz > 0);
assert(step >= 0);
assert(beam_size > 0);
assert(no_repeat_ngram_size > 0);
return ngram_repeat_block_cuda_forward(
tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&ngram_repeat_block_forward,
"No Repeat Ngram Block forward (CUDA)");
}

View File

@@ -0,0 +1,82 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for blocking repeated n-grams.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>
// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(
long* __restrict__ tokens,
float* __restrict__ lprobs,
int max_predict_len,
int vocab_size,
int no_repeat_ngram_size) {
auto row = blockIdx.x;
auto col = threadIdx.x;
auto start = row * (max_predict_len) + col;
// Each thread compares ngram starting from
// thread index with final ngram starting from
// step - no_repeat_ngram_size +2
auto check_start_pos = blockDim.x;
auto lprob_start = row * vocab_size;
bool is_banned = true;
extern __shared__ long tokens_shm[];
tokens_shm[col] = tokens[start];
if (col == blockDim.x - 1) {
for (int i = 1; i < no_repeat_ngram_size; i++) {
if (col + i < max_predict_len) {
tokens_shm[col + i] = tokens[start + i];
}
}
}
__syncthreads();
for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
is_banned = false;
}
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}
}
// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
const torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size) {
int threads = step - no_repeat_ngram_size + 2;
if (threads <= 0)
return lprobs;
int max_predict_len = tokens.size(1);
int vocab_size = lprobs.size(1);
auto token_ptr = tokens.data_ptr<long>();
auto lprob_ptr = lprobs.data_ptr<float>();
int blocks = bsz * beam_size;
int shared_mem_size = (step + 1) * sizeof(long);
// Launching N blocks where N is number of samples in a batch (beams*bsz)
// Launching T threads where T is number of previous ngrams in a sample
// Allocating shared mem per block for fastser access of input tokens since
// each token will be accessed N times to compare with current Ngram where
// N is Ngram size.
banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
return lprobs;
}

View File

@@ -0,0 +1,109 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from: https://github.com/bkj/auction-lap Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
int max_iterations = 100;
torch::Tensor epsilon =
(job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
epsilon.clamp_min_(1e-04);
torch::Tensor worker_and_job_to_score =
job_and_worker_to_score.detach().transpose(0, 1).contiguous();
int num_workers = worker_and_job_to_score.size(0);
int num_jobs = worker_and_job_to_score.size(1);
auto device = worker_and_job_to_score.device();
int jobs_per_worker = num_jobs / num_workers;
torch::Tensor value = worker_and_job_to_score.clone();
int counter = 0;
torch::Tensor max_value = worker_and_job_to_score.max();
torch::Tensor bid_indices;
torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
torch::Tensor bids =
worker_and_job_to_score.new_empty({num_workers, num_jobs});
torch::Tensor bid_increments =
worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
torch::Tensor top_values =
worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});
torch::Tensor top_index = top_values.to(torch::kLong);
torch::Tensor high_bidders = top_index.new_empty({num_jobs});
torch::Tensor have_bids = high_bidders.to(torch::kBool);
torch::Tensor jobs_indices =
torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
torch::Tensor true_tensor =
torch::ones({1}, torch::dtype(torch::kBool).device(device));
while (true) {
bids.zero_();
torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);
// Each worker bids the difference in value between that job and the k+1th
// job
torch::sub_out(
bid_increments,
top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
bid_increments.add_(epsilon);
bids.scatter_(
1,
top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
bid_increments);
if (counter < max_iterations && counter > 0) {
// Put in a minimal bid to retain items from the last round if no-one else
// bids for them this round
bids.view(-1).index_put_({bid_indices}, epsilon);
}
// Find the highest bidding worker per job
torch::max_out(high_bids, high_bidders, bids, 0);
torch::gt_out(have_bids, high_bids, 0);
if (have_bids.all().item<bool>()) {
// All jobs were bid for
break;
}
// Make popular items more expensive
cost.add_(high_bids);
torch::sub_out(value, worker_and_job_to_score, cost);
bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});
if (counter < max_iterations) {
// Make sure that this item will be in the winning worker's top-k next
// time.
value.view(-1).index_put_({bid_indices}, max_value);
} else {
// Suboptimal approximation that converges quickly from current solution
value.view(-1).index_put_(
{bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
}
counter += 1;
}
return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
.reshape(-1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}

View File

@@ -0,0 +1,157 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <array>
#include <cstdio>
#include <cstring>
#include <map>
// NOLINTNEXTLINE
typedef struct {
size_t reflen;
size_t predlen;
size_t match1;
size_t count1;
size_t match2;
size_t count2;
size_t match3;
size_t count3;
size_t match4;
size_t count4;
} bleu_stat;
// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
size_t start = 0;
while (start < *len) {
if (*(*sent + start) != pad) {
break;
}
start++;
}
*sent += start;
*len -= start;
}
// right trim remove (eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
size_t end = *len - 1;
while (end > 0) {
if (*(*sent + end) != eos && *(*sent + end) != pad) {
break;
}
end--;
}
*len = end + 1;
}
// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
bleu_ltrim(len, sent, pad);
bleu_rtrim(len, sent, pad, eos);
}
size_t bleu_hash(int len, int* data) {
size_t h = 14695981039346656037ul;
size_t prime = 0x100000001b3;
char* b = (char*)data;
size_t blen = sizeof(int) * len;
while (blen-- > 0) {
h ^= *b++;
h *= prime;
}
return h;
}
void bleu_addngram(
size_t* ntotal,
size_t* nmatch,
size_t n,
size_t reflen,
int* ref,
size_t predlen,
int* pred) {
if (predlen < n) {
return;
}
predlen = predlen - n + 1;
(*ntotal) += predlen;
if (reflen < n) {
return;
}
reflen = reflen - n + 1;
std::map<size_t, size_t> count;
while (predlen > 0) {
size_t w = bleu_hash(n, pred++);
count[w]++;
predlen--;
}
while (reflen > 0) {
size_t w = bleu_hash(n, ref++);
if (count[w] > 0) {
(*nmatch)++;
count[w] -= 1;
}
reflen--;
}
}
extern "C" {
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_zero_init(bleu_stat* stat) {
std::memset(stat, 0, sizeof(bleu_stat));
}
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_one_init(bleu_stat* stat) {
bleu_zero_init(stat);
stat->count1 = 0;
stat->count2 = 1;
stat->count3 = 1;
stat->count4 = 1;
stat->match1 = 0;
stat->match2 = 1;
stat->match3 = 1;
stat->match4 = 1;
}
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_add(
bleu_stat* stat,
size_t reflen,
int* ref,
size_t predlen,
int* pred,
int pad,
int eos) {
bleu_trim(&reflen, &ref, pad, eos);
bleu_trim(&predlen, &pred, pad, eos);
stat->reflen += reflen;
stat->predlen += predlen;
bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}

View File

@@ -0,0 +1,33 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <Python.h>
static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"libbleu", /* name of module */
// NOLINTNEXTLINE
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_def}; // NOLINT
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
PyObject* m = PyModule_Create(&module_def);
if (!m) {
return NULL;
}
return m;
}

View File

@@ -0,0 +1,231 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>
using namespace ::std;
vector<vector<uint32_t>> edit_distance2_with_dp(
vector<uint32_t>& x,
vector<uint32_t>& y) {
uint32_t lx = x.size();
uint32_t ly = y.size();
vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
for (uint32_t i = 0; i < lx + 1; i++) {
d[i][0] = i;
}
for (uint32_t j = 0; j < ly + 1; j++) {
d[0][j] = j;
}
for (uint32_t i = 1; i < lx + 1; i++) {
for (uint32_t j = 1; j < ly + 1; j++) {
d[i][j] =
min(min(d[i - 1][j], d[i][j - 1]) + 1,
d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
}
}
return d;
}
vector<vector<uint32_t>> edit_distance2_backtracking(
vector<vector<uint32_t>>& d,
vector<uint32_t>& x,
vector<uint32_t>& y,
uint32_t terminal_symbol) {
vector<uint32_t> seq;
vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
/*
edit_seqs:
0~x.size() cell is the insertion sequences
last cell is the delete sequence
*/
if (x.size() == 0) {
edit_seqs.at(0) = y;
return edit_seqs;
}
uint32_t i = d.size() - 1;
uint32_t j = d.at(0).size() - 1;
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
seq.push_back(1); // insert
seq.push_back(y.at(j - 1));
j--;
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
seq.push_back(2); // delete
seq.push_back(x.at(i - 1));
i--;
} else {
seq.push_back(3); // keep
seq.push_back(x.at(i - 1));
i--;
j--;
}
}
uint32_t prev_op, op, s, word;
prev_op = 0, s = 0;
for (uint32_t k = 0; k < seq.size() / 2; k++) {
op = seq.at(seq.size() - 2 * k - 2);
word = seq.at(seq.size() - 2 * k - 1);
if (prev_op != 1) {
s++;
}
if (op == 1) // insert
{
edit_seqs.at(s - 1).push_back(word);
} else if (op == 2) // delete
{
edit_seqs.at(x.size() + 1).push_back(1);
} else {
edit_seqs.at(x.size() + 1).push_back(0);
}
prev_op = op;
}
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
if (edit_seqs[k].size() == 0) {
edit_seqs[k].push_back(terminal_symbol);
}
}
return edit_seqs;
}
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
vector<vector<uint32_t>>& d,
vector<uint32_t>& x,
vector<uint32_t>& y,
uint32_t terminal_symbol,
uint32_t deletion_symbol) {
vector<uint32_t> seq;
vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
/*
edit_seqs:
0~x.size() cell is the insertion sequences
last cell is the delete sequence
*/
if (x.size() == 0) {
edit_seqs.at(0) = y;
return edit_seqs;
}
uint32_t i = d.size() - 1;
uint32_t j = d.at(0).size() - 1;
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
seq.push_back(1); // insert
seq.push_back(y.at(j - 1));
j--;
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
seq.push_back(2); // delete
seq.push_back(x.at(i - 1));
i--;
} else {
seq.push_back(3); // keep
seq.push_back(x.at(i - 1));
i--;
j--;
}
}
uint32_t prev_op, op, s, word;
prev_op = 0, s = 0;
for (uint32_t k = 0; k < seq.size() / 2; k++) {
op = seq.at(seq.size() - 2 * k - 2);
word = seq.at(seq.size() - 2 * k - 1);
if (prev_op != 1) {
s++;
}
if (op == 1) // insert
{
edit_seqs.at(s - 1).push_back(word);
} else if (op == 2) // delete
{
edit_seqs.at(s - 1).push_back(deletion_symbol);
}
prev_op = op;
}
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
if (edit_seqs.at(k).size() == 0) {
edit_seqs.at(k).push_back(terminal_symbol);
}
}
return edit_seqs;
}
vector<uint32_t> compute_ed2(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys) {
vector<uint32_t> distances(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
}
return distances;
}
vector<vector<vector<uint32_t>>> suggested_ed2_path(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys,
uint32_t terminal_symbol) {
vector<vector<vector<uint32_t>>> seq(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
seq.at(i) =
edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
}
return seq;
}
vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys,
uint32_t terminal_symbol,
uint32_t deletion_symbol) {
vector<vector<vector<uint32_t>>> seq(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
seq.at(i) = edit_distance2_backtracking_with_delete(
d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
}
return seq;
}
PYBIND11_MODULE(libnat, m) {
m.def("compute_ed2", &compute_ed2, "compute_ed2");
m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
m.def(
"suggested_ed2_path_with_delete",
&suggested_ed2_path_with_delete,
"suggested_ed2_path_with_delete");
}

View File

@@ -0,0 +1,67 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
This code is partially adpoted from
https://github.com/1ytic/pytorch-edit-distance
*/
#include <torch/types.h>
#include "edit_dist.h"
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor LevenshteinDistance(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
CHECK_INPUT(source);
CHECK_INPUT(target);
CHECK_INPUT(source_length);
CHECK_INPUT(target_length);
return LevenshteinDistanceCuda(source, target, source_length, target_length);
}
torch::Tensor GenerateDeletionLabel(
torch::Tensor source,
torch::Tensor operations) {
CHECK_INPUT(source);
CHECK_INPUT(operations);
return GenerateDeletionLabelCuda(source, operations);
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
torch::Tensor target,
torch::Tensor operations) {
CHECK_INPUT(target);
CHECK_INPUT(operations);
return GenerateInsertionLabelCuda(target, operations);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
m.def(
"generate_deletion_labels",
&GenerateDeletionLabel,
"Generate Deletion Label");
m.def(
"generate_insertion_labels",
&GenerateInsertionLabel,
"Generate Insertion Label");
}

View File

@@ -0,0 +1,344 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "edit_dist.h"
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
const scalar_t* __restrict__ source,
const size_t source_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * source_size;
for (int i = 0; i < source_size; i++) {
labels[offset_label + i] = 0;
}
int k = 0;
for (int i = 0; i < operation_size; i++) {
if (operations[offset + i] == 0) {
break;
} else if (operations[offset + i] == 1) {
continue;
} else {
labels[offset_label + k] = 3 - operations[offset + i];
k++;
}
}
}
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
const scalar_t* __restrict__ target,
const size_t target_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels,
int* __restrict__ masks) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * target_size;
int k = 0;
int u = 0;
int m = 0;
for (int i = 0; i < target_size; i++) {
labels[offset_label + i] = 0;
masks[offset_label + i] = 0;
}
for (int i = 0; i < operation_size - 1; i++) {
if (operations[offset + i] == 0) {
break;
} else if (operations[offset + i] == 2) {
continue;
} else if (operations[offset + i] == 1) {
masks[offset_label + m] = 1;
u++;
m++;
} else {
labels[offset_label + k] = u;
masks[offset_label + m] = 0;
k++;
m++;
u = 0;
}
}
}
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations,
int* __restrict__ errors_curr) {
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int d = index * (source_size + 1) * (target_size + 1);
const int t = target_size + 1;
auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++) {
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++) {
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++) {
for (int j = 1; j <= ref_len; j++) {
errors_curr[err_idx(i, j)] = min(
min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
1,
errors_curr[err_idx(i - 1, j - 1)] +
2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) &&
(errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 1;
j--; // insertion
} else if (
(i > 0) &&
(errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 2;
i--; // deletion
} else {
o--;
operations[opt_idx(o)] = 3;
i--;
j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len) {
operations[opt_idx(k)] = operations[opt_idx(k + o)];
} else {
operations[opt_idx(k)] = 0; // padding
}
}
}
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations) {
extern __shared__ short errors[];
auto errors_curr = errors;
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int t = target_size + 1;
auto err_idx = [t](int i, int j) { return i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++) {
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++) {
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++) {
for (int j = 1; j <= ref_len; j++) {
errors_curr[err_idx(i, j)] = min(
min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
1,
errors_curr[err_idx(i - 1, j - 1)] +
2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) &&
(errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 1;
j--; // insertion
} else if (
(i > 0) &&
(errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 2;
i--; // deletion
} else {
o--;
operations[opt_idx(o)] = 3;
i--;
j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len) {
operations[opt_idx(k)] = operations[opt_idx(k + o)];
} else {
operations[opt_idx(k)] = 0; // padding
}
}
}
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations) {
const auto batch_size = source.size(0);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, source.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
generate_deletion_label_kernel<scalar_t>
<<<batch_size, 1, 0, stream>>>(
source.data_ptr<scalar_t>(),
source.size(1),
operations.size(1),
operations.data_ptr<int>(),
labels.data_ptr<int>());
}));
return labels;
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor target,
torch::Tensor operations) {
const auto batch_size = target.size(0);
at::TensorOptions options(target.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, target.size(1)}, options);
auto masks = torch::empty({batch_size, target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
AT_DISPATCH_ALL_TYPES(
target.scalar_type(), "generate_insertion_labels", ([&] {
generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
target.data_ptr<scalar_t>(),
target.size(1),
operations.size(1),
operations.data_ptr<int>(),
labels.data_ptr<int>(),
masks.data_ptr<int>());
}));
return std::make_pair(labels, masks);
}
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
const auto batch_size = source.size(0);
const auto shared_size =
(source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto operations =
torch::empty({batch_size, source.size(1) + target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
if (shared_size > 40000) {
auto distances = torch::empty(
{batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
levenshtein_distance_kernel<scalar_t>
<<<batch_size, 1, 0, stream>>>(
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data_ptr<int>(),
distances.data_ptr<int>());
}));
} else {
AT_DISPATCH_ALL_TYPES(
source.scalar_type(), "faster_levenshtein_distance", ([&] {
faster_levenshtein_distance_kernel<scalar_t>
<<<batch_size, 1, shared_size, stream>>>(
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data_ptr<int>());
}));
}
return operations;
}

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length);
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations);
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor source,
torch::Tensor operations);

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,19 @@
# @package _group_
hydra:
run:
dir: .
defaults:
- _self_
- task: null
- model: null
- criterion: cross_entropy
- optimizer: null
- lr_scheduler: fixed
- bpe: null
- tokenizer: null
- scoring: null
- generation: null
- common_eval: null
- eval_lm: null

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.0
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 768
decoder_output_dim: 768
decoder_input_dim: 768
decoder_ffn_embed_dim: 3072
decoder_layers: 12
decoder_attention_heads: 12
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1600
decoder_output_dim: 1600
decoder_input_dim: 1600
decoder_ffn_embed_dim: 6400
decoder_layers: 48
decoder_attention_heads: 25
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1280
decoder_output_dim: 1280
decoder_input_dim: 1280
decoder_ffn_embed_dim: 5120
decoder_layers: 36
decoder_attention_heads: 20
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 24
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,36 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -0,0 +1,5 @@
# @package _group_
activation: gelu
vq_type: gumbel
vq_depth: 2
combine_groups: true

View File

@@ -0,0 +1,8 @@
# @package _group_
quantize_targets: true
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
feature_grad_mult: 0.1

View File

@@ -0,0 +1,20 @@
# @package _group_
quantize_targets: true
extractor_mode: layer_norm
layer_norm_first: true
final_dim: 768
latent_temp: [2.0,0.1,0.999995]
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
conv_bias: true
encoder_layers: 24
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16
feature_grad_mult: 1.0

View File

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from fairseq import registry
from fairseq.criterions.fairseq_criterion import ( # noqa
FairseqCriterion,
LegacyFairseqCriterion,
)
from omegaconf import DictConfig
(
build_criterion_,
register_criterion,
CRITERION_REGISTRY,
CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
"--criterion", base_class=FairseqCriterion, default="cross_entropy"
)
def build_criterion(cfg: DictConfig, task):
return build_criterion_(cfg, task)
# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("fairseq.criterions." + file_name)

View File

@@ -0,0 +1,123 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
@classmethod
def build_criterion(cls, cfg: AdaptiveLossConfig, task):
if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
raise Exception(
"AdaptiveLoss is not compatible with the PyTorch "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=legacy_ddp` instead."
)
return cls(task, cfg.sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model.decoder, "adaptive_softmax")
and model.decoder.adaptive_softmax is not None
)
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample["net_input"])
orig_target = model.get_targets(sample, net_output)
nsentences = orig_target.size(0)
orig_target = orig_target.view(-1)
bsz = orig_target.size(0)
logits, target = adaptive_softmax(net_output[0], orig_target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
loss += F.cross_entropy(
logits[i],
target[i],
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn
@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
def __init__(self, args, task):
super().__init__(args, task)
self.underlying_criterion = args.underlying_criterion
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
help='underlying criterion to use for the composite loss')
# fmt: on
@staticmethod
def build_underlying_criterion(args, task):
saved_criterion = args.criterion
args.criterion = args.underlying_criterion
assert saved_criterion != args.underlying_criterion
underlying_criterion = task.build_criterion(args)
args.criterion = saved_criterion
return underlying_criterion
@classmethod
def build_criterion(cls, args, task):
underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
class FakeModel(nn.Module):
def __init__(self, model, net_out, target):
super().__init__()
self.model = model
self.net_out = net_out
self.target = target
def forward(self, **unused):
return self.net_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
return self.model.get_normalized_probs(
net_output, log_probs, sample=sample
)
def get_targets(self, *unused):
return self.target
@property
def decoder(self):
return self.model.decoder
class _CompositeLoss(LegacyFairseqCriterion):
def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
self.underlying_criterion = underlying_criterion
def forward(self, model, sample, reduce=True):
net_outputs = model(**sample["net_input"])
targets = sample["target"]
bsz = targets[0].size(0)
loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
sample_size = 0
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = FakeModel(model, (o, net_outputs[1]), t)
sample["target"] = t
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
loss.div_(len(targets))
sample_size /= len(targets)
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
return underlying_criterion.__class__.aggregate_logging_outputs(
logging_outputs
)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)
return _CompositeLoss(args, task, underlying_criterion)

View File

@@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(
lprobs,
target,
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
return loss, loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,319 @@
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.tasks import FairseqTask
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model: Optional[str] = field(
default=None,
metadata={
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon: Optional[str] = field(
default=None,
metadata={"help": "lexicon to use with wer_kenlm_model"},
)
wer_lm_weight: float = field(
default=2.0,
metadata={"help": "lm weight to use with wer_kenlm_model"},
)
wer_word_score: float = field(
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
wer_args: Optional[str] = field(
default=None,
metadata={
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
def __init__(
self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0
):
super().__init__(task)
self.blank_idx = (
task.target_dictionary.index(task.blank_symbol)
if hasattr(task, "blank_symbol")
else 0
)
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
self.rdrop_alpha = rdrop_alpha
if cfg.wer_args is not None:
(
cfg.wer_kenlm_model,
cfg.wer_lexicon,
cfg.wer_lm_weight,
cfg.wer_word_score,
) = eval(cfg.wer_args)
if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = cfg.wer_kenlm_model
dec_args.lexicon = cfg.wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = cfg.wer_lm_weight
dec_args.word_score = cfg.wer_word_score
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity
self.sentence_avg = cfg.sentence_avg
def forward(self, model, sample, reduce=True, **kwargs):
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
# CTC loss is calculated over duplicated inputs
# sample is already duplicated for R-Drop
if self.rdrop_alpha > 0:
for k, v in sample.items():
if k in ["target", "target_lengths"]:
sample[k] = torch.cat([v, v.clone()], dim=0)
elif k == "net_input":
if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size(
0
):
# for decoder CTC loss
sample[k]["src_lengths"] = torch.cat(
[
sample[k]["src_lengths"],
sample[k]["src_lengths"].clone(),
],
dim=0,
)
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
if net_output["padding_mask"] is not None:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
else:
input_lengths = lprobs.new_full(
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
if "target_lengths" in sample:
target_lengths = sample["target_lengths"]
else:
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
if not model.training:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Dict, List
from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, task):
super().__init__()
self.task = task
if hasattr(task, "target_dictionary"):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@classmethod
def build_criterion(cls, cfg: FairseqDataclass, task):
"""Construct a criterion from command-line args."""
# arguments in the __init__.
init_args = {}
for p in inspect.signature(cls).parameters.values():
if (
p.kind == p.POSITIONAL_ONLY
or p.kind == p.VAR_POSITIONAL
or p.kind == p.VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise NotImplementedError("{} not supported".format(p.kind))
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
if p.name == "task":
init_args["task"] = task
elif p.name == "cfg":
init_args["cfg"] = cfg
elif hasattr(cfg, p.name):
init_args[p.name] = getattr(cfg, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
"Unable to infer Criterion arguments, please implement "
"{}.build_criterion".format(cls.__name__)
)
return cls(**init_args)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
@staticmethod
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
raise NotImplementedError
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"Criterions should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
if k in {"nsentences", "ntokens", "sample_size"}:
continue
metrics.log_scalar(k, v)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
class LegacyFairseqCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(task=task)
self.args = args
utils.deprecation_warning(
"Criterions should take explicit arguments instead of an "
"argparse.Namespace object, please update your criterion by "
"extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
return cls(args, task)

View File

@@ -0,0 +1,136 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from typing import List, Dict, Any
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel
@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
def __init__(self, task, ctc_weight):
super().__init__(task)
self.ctc_weight = ctc_weight
def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
_feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
durations=sample["durations"],
pitches=sample["pitches"],
energies=sample["energies"],
)
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
tgt_mask = lengths_to_mask(sample["target_lengths"])
pitches, energies = sample["pitches"], sample["energies"]
pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
energy_out, energies = energy_out[src_mask], energies[src_mask]
feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
if _feat_out_post is not None:
l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
log_dur_out = log_dur_out[src_mask]
dur = sample["durations"].float()
dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
sample_size = sample["nsentences"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"dur_loss": utils.item(dur_loss.data),
"pitch_loss": utils.item(pitch_loss.data),
"energy_loss": utils.item(energy_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in [
"loss",
"l1_loss",
"dur_loss",
"pitch_loss",
"energy_loss",
"ctc_loss",
]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False

View File

@@ -0,0 +1,194 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class HubertCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss = 0.0
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False

View File

@@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
sentence_avg: bool = II("optimization.sentence_avg")
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
@register_criterion(
"label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
# lprobs: B x T x C
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,220 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
try:
from simuleval.metrics.latency import (
AverageLagging,
AverageProportion,
DifferentiableAverageLagging,
)
LATENCY_METRICS = {
"average_lagging": AverageLagging,
"average_proportion": AverageProportion,
"differentiable_average_lagging": DifferentiableAverageLagging,
}
except ImportError:
LATENCY_METRICS = None
@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
latency_avg_weight: float = field(
default=0.0,
metadata={"help": "weight fot average latency loss."},
)
latency_var_weight: float = field(
default=0.0,
metadata={"help": "weight fot variance latency loss."},
)
latency_avg_type: str = field(
default="differentiable_average_lagging",
metadata={"help": "latency type for average loss"},
)
latency_var_type: str = field(
default="variance_delay",
metadata={"help": "latency typ for variance loss"},
)
latency_gather_method: str = field(
default="weighted_average",
metadata={"help": "method to gather latency loss for all heads"},
)
latency_update_after: int = field(
default=0,
metadata={"help": "Add latency loss after certain steps"},
)
@register_criterion(
"latency_augmented_label_smoothed_cross_entropy",
dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
latency_avg_weight,
latency_var_weight,
latency_avg_type,
latency_var_type,
latency_gather_method,
latency_update_after,
):
super().__init__(
task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
)
assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."
self.latency_avg_weight = latency_avg_weight
self.latency_var_weight = latency_var_weight
self.latency_avg_type = latency_avg_type
self.latency_var_type = latency_var_type
self.latency_gather_method = latency_gather_method
self.latency_update_after = latency_update_after
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
# 1. Compute cross entropy loss
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
# 2. Compute cross latency loss
latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
model, sample, net_output
)
if self.latency_update_after > 0:
num_updates = getattr(model.decoder, "num_updates", None)
assert (
num_updates is not None
), "model.decoder doesn't have attribute 'num_updates'"
if num_updates <= self.latency_update_after:
latency_loss = 0
loss += latency_loss
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"latency": expected_latency,
"delays_var": expected_delays_var,
"latency_loss": latency_loss,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def compute_latency_loss(self, model, sample, net_output):
assert (
net_output[-1].encoder_padding_mask is None
or not net_output[-1].encoder_padding_mask[:, 0].any()
), "Only right padding on source is supported."
# 1. Obtain the expected alignment
alpha_list = [item["alpha"] for item in net_output[1].attn_list]
num_layers = len(alpha_list)
bsz, num_heads, tgt_len, src_len = alpha_list[0].size()
# bsz * num_layers * num_heads, tgt_len, src_len
alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)
# 2 compute expected delays
# bsz * num_heads * num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(alpha_all)
.type_as(alpha_all)
)
expected_delays = torch.sum(steps * alpha_all, dim=-1)
target_padding_mask = (
model.get_targets(sample, net_output)
.eq(self.padding_idx)
.unsqueeze(1)
.expand(bsz, num_layers * num_heads, tgt_len)
.contiguous()
.view(-1, tgt_len)
)
src_lengths = (
sample["net_input"]["src_lengths"]
.unsqueeze(1)
.expand(bsz, num_layers * num_heads)
.contiguous()
.view(-1)
)
expected_latency = LATENCY_METRICS[self.latency_avg_type](
expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
)
# 2.1 average expected latency of heads
# bsz, num_layers * num_heads
expected_latency = expected_latency.view(bsz, -1)
if self.latency_gather_method == "average":
# bsz * tgt_len
expected_latency = expected_delays.mean(dim=1)
elif self.latency_gather_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_latency, dim=1)
expected_latency = torch.sum(expected_latency * weights, dim=1)
elif self.latency_gather_method == "max":
expected_latency = expected_latency.max(dim=1)[0]
else:
raise NotImplementedError
expected_latency = expected_latency.sum()
avg_loss = self.latency_avg_weight * expected_latency
# 2.2 variance of expected delays
expected_delays_var = (
expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
)
expected_delays_var = expected_delays_var.sum()
var_loss = self.latency_avg_weight * expected_delays_var
# 3. Final loss
latency_loss = avg_loss + var_loss
return latency_loss, expected_latency, expected_delays_var
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
latency = sum(log.get("latency", 0) for log in logging_outputs)
delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
metrics.log_scalar(
"latency_loss", latency_loss / nsentences, nsentences, round=3
)

View File

@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from .label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
from dataclasses import dataclass, field
@dataclass
class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
alignment_lambda: float = field(
default=0.05, metadata={"help": "weight for the alignment loss"}
)
@register_criterion(
"label_smoothed_cross_entropy_with_alignment",
dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig,
)
class LabelSmoothedCrossEntropyCriterionWithAlignment(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
super().__init__(task, sentence_avg, label_smoothing)
self.alignment_lambda = alignment_lambda
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
alignment_loss = None
# Compute alignment loss only for training set and non dummy batches.
if "alignments" in sample and sample["alignments"] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
logging_output["alignment_loss"] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]["attn"][0]
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
align = sample["alignments"]
align_weights = sample["align_weights"].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -(
(attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
* align_weights[:, None]
).sum()
else:
return None
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
alignment_loss_sum = utils.item(
sum(log.get("alignment_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_scalar(
"alignment_loss",
alignment_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.data.data_utils import lengths_to_mask
@dataclass
class LabelSmoothedCrossEntropyWithCtcCriterionConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"})
@register_criterion(
"label_smoothed_cross_entropy_with_ctc",
dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig,
)
class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
ctc_weight,
):
super().__init__(
task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
)
self.ctc_weight = ctc_weight
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
ctc_loss = torch.tensor(0.0).type_as(loss)
if self.ctc_weight > 0.0:
ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample)
ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample)
ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens)
ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask)
reduction = "sum" if reduce else "none"
ctc_loss = (
F.ctc_loss(
ctc_lprobs,
ctc_tgt_flat,
ctc_lens,
ctc_tgt_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss += ctc_loss
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data),
"nll_loss": utils.item(nll_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)

View File

@@ -0,0 +1,176 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
label_smoothed_nll_loss,
)
@dataclass
class RdropLabelSmoothedCrossEntropyCriterionConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
rdrop_alpha: float = field(
default=0.0,
metadata={"help": "alpha for r-drop, 0 means no r-drop"},
)
@register_criterion(
"label_smoothed_cross_entropy_with_rdrop",
dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig,
)
class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=ignore_prefix_size,
report_accuracy=report_accuracy,
)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
self.rdrop_alpha = rdrop_alpha
def forward(self, model, sample, reduce=True, net_output=None):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if net_output is None:
if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size(
0
) == sample["target"].size(0):
sample = duplicate_input(sample)
net_output = model(**sample["net_input"])
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, net_output, sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0):
target = torch.cat([target, target.clone()], dim=0)
if self.ignore_prefix_size > 0:
# lprobs: B x T x C
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
if self.rdrop_alpha > 0:
pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx)
rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask)
loss += self.rdrop_alpha * rdrop_kl_loss
else:
rdrop_kl_loss = loss.new_zeros(1)
return loss, nll_loss, rdrop_kl_loss
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
super().reduce_metrics(logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
rdrop_kl_loss = utils.item(
sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs)
/ sample_size
/ math.log(2)
)
if rdrop_kl_loss > 0:
metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss)
def duplicate_input(sample):
if "net_input" in sample.keys():
sample_input = sample["net_input"]
else:
sample_input = sample
for k, v in sample_input.items():
if isinstance(v, torch.Tensor):
sample_input[k] = torch.cat([v, v.clone()], dim=0)
if "net_input" in sample.keys():
sample["net_input"] = sample_input
else:
sample = sample_input
return sample
def compute_kl_loss(model, net_output, pad_mask=None, reduce=True):
net_prob = model.get_normalized_probs(net_output, log_probs=True)
net_prob_tec = model.get_normalized_probs(net_output, log_probs=False)
net_prob = net_prob.view(-1, net_prob.size(-1))
net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1))
p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0)
p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0)
p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none")
q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none")
if pad_mask is not None:
p_loss.masked_fill_(pad_mask, 0.0)
q_loss.masked_fill_(pad_mask, 0.0)
if reduce:
p_loss = p_loss.sum()
q_loss = q_loss.sum()
loss = (p_loss + q_loss) / 2
return loss

View File

@@ -0,0 +1,177 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
"""
Function to compute the cross entropy loss. The default value of
ignore_index is the same as the default value for F.cross_entropy in
pytorch.
"""
assert logits.size(0) == targets.size(
-1
), "Logits and Targets tensor shapes don't match up"
loss = F.nll_loss(
F.log_softmax(logits, -1, dtype=torch.float32),
targets,
reduction="sum",
ignore_index=ignore_index,
)
return loss
@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
This optionally also computes the next sentence prediction (NSP) loss and
adds it to the overall loss based on the specified args. There are three
cases to consider:
1) Generic MLM training without NSP loss. In this case sentence_targets
and sentence_logits are both None.
2) BERT training without NSP loss. In this case sentence_targets is
not None but sentence_logits is None and we should not be computing
a sentence level loss.
3) BERT training with NSP loss. In this case both sentence_targets and
sentence_logits are not None and we should be computing a sentence
level loss. The weight of the sentence level loss is specified as
an argument.
"""
def __init__(self, task, masked_lm_only, nsp_loss_weight):
super().__init__(task)
self.masked_lm_only = masked_lm_only
self.nsp_loss_weight = nsp_loss_weight
@staticmethod
def add_args(parser):
"""Args for MaskedLM Loss"""
# Default for masked_lm_only is False so as to not break BERT training
parser.add_argument(
"--masked-lm-only",
default=False,
action="store_true",
help="compute MLM loss only",
)
parser.add_argument(
"--nsp-loss-weight",
default=1.0,
type=float,
help="weight for next sentence prediction" " loss (default 1)",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
lm_logits, output_metadata = model(**sample["net_input"])
# reshape lm_logits from (N,T,C) to (N*T,C)
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
lm_targets = sample["lm_target"].view(-1)
lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
# compute the number of tokens for which loss is computed. This is used
# to normalize the loss
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
loss = lm_loss / ntokens
nsentences = sample["nsentences"]
# nsentences = 0
# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.masked_lm_only:
sentence_logits = output_metadata["sentence_logits"]
sentence_targets = sample["sentence_target"].view(-1)
# This needs to be recomputed due to some differences between
# TokenBlock and BlockPair dataset. This can be resolved with a
# refactor of BERTModel which we will do in the future.
# TODO: Remove this after refactor of BERTModel
nsentences = sentence_targets.size(0)
# Check for logits being none which can happen when remove_heads
# is set to true in the BERT model. Ideally we should set
# masked_lm_only to true in this case, but that requires some
# refactor in the BERT model.
if sentence_logits is not None:
sentence_loss = compute_cross_entropy_loss(
sentence_logits, sentence_targets
)
loss += self.nsp_loss_weight * (sentence_loss / nsentences)
# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
# sentence loss is not always computed
"sentence_loss": (
(utils.item(sentence_loss.data) if reduce else sentence_loss.data)
if sentence_loss is not None
else 0.0
),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss",
agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
metrics.log_scalar(
"lm_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
metrics.log_scalar(
"sentence_loss",
sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
nsentences,
round=3,
)
metrics.log_scalar(
"nll_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,98 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import math
from omegaconf import II
import torch
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class MaskedLmConfig(FairseqDataclass):
tpu: bool = II("common.tpu")
@register_criterion("masked_lm", dataclass=MaskedLmConfig)
class MaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
"""
def __init__(self, cfg: MaskedLmConfig, task):
super().__init__(task)
self.tpu = cfg.tpu
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
masked_tokens = sample["target"].ne(self.padding_idx)
sample_size = masked_tokens.int().sum()
# Rare: when all tokens are masked, project all tokens.
# We use torch.where to avoid device-to-host transfers,
# except on CPU where torch.where is not well supported
# (see github.com/pytorch/pytorch/issues/26247).
if self.tpu:
masked_tokens = None # always project all tokens on TPU
elif masked_tokens.device == torch.device("cpu"):
if not masked_tokens.any():
masked_tokens = None
else:
masked_tokens = torch.where(
masked_tokens.any(),
masked_tokens,
masked_tokens.new([True]),
)
logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
if masked_tokens is not None:
targets = targets[masked_tokens]
loss = modules.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
reduction="sum",
ignore_index=self.padding_idx,
)
logging_output = {
"loss": loss if self.tpu else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,155 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Dict, List
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class ModelCriterionConfig(FairseqDataclass):
loss_weights: Dict[str, float] = field(
default_factory=dict,
metadata={"help": "weights for the loss terms"},
)
log_keys: List[str] = field(
default_factory=list,
metadata={"help": "additional output keys to log"},
)
@register_criterion("model", dataclass=ModelCriterionConfig)
class ModelCriterion(FairseqCriterion):
"""
This criterion relies on the model to supply losses.
The losses should be a dictionary of name -> scalar returned by
the model either by including it in the net_output dict or by
implementing a get_losses(net_output, sample) method. The final loss is
a scaled sum of all losses according to weights in loss_weights.
If no weights are provided, then all losses are scaled by 1.0.
The losses will be automatically logged. Additional keys from
net_output dict can be logged via the log_keys parameter.
"""
def __init__(self, task, loss_weights=None, log_keys=None):
super().__init__(task)
self.loss_weights = loss_weights
self.log_keys = log_keys
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
scaled_losses = {}
if hasattr(model, "get_losses"):
losses = model.get_losses(net_output, sample)
elif isinstance(net_output, dict) and "losses" in net_output:
losses = net_output["losses"]
else:
raise Exception("Could not retrieve losses")
for lk, p in losses.items():
try:
coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk]
except KeyError:
logger.error(
f"weight for loss {lk} is not in loss_weights ({self.loss_weights})"
)
raise
if coef != 0 and p is not None:
scaled_losses[lk] = coef * p.float()
loss = sum(scaled_losses.values())
if "sample_size" in net_output:
sample_size = net_output["sample_size"]
else:
sample_size = loss.numel()
if reduce and loss.numel() > 1:
loss = loss.sum()
logging_output = {
"loss": loss.data,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
"_world_size": 1,
}
for lk in self.log_keys:
if lk in net_output and net_output[lk] is not None:
if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
logging_output[lk] = float(net_output[lk])
else:
for i, v in enumerate(net_output[lk]):
logging_output[f"{lk}_{i}"] = float(v)
if len(scaled_losses) > 1:
for lk, l in scaled_losses.items():
if l.numel() > 1:
l = l.sum()
logging_output[f"loss_{lk}"] = l.item()
if "logs" in net_output:
for lgw in net_output["logs"]:
logging_output[lgw] = net_output["logs"][lgw]
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"_world_size",
}
world_size = utils.item(
sum(log.get("_world_size", 0) for log in logging_outputs)
)
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs)
if k.startswith("loss_"):
metrics.log_scalar(k, val / sample_size, sample_size, round=3)
else:
metrics.log_scalar(k, val / world_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,180 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from torch import Tensor
from dataclasses import dataclass, field
@dataclass
class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig)
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing
def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
):
"""
outputs: batch x len x d_model
targets: batch x len
masks: batch x len
policy_logprob: if there is some policy
depends on the likelihood score as rewards.
"""
def mean_ds(x: Tensor, dim=None) -> Tensor:
return (
x.float().mean().type_as(x)
if dim is None
else x.float().mean(dim).type_as(x)
)
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
logits = F.log_softmax(outputs, dim=-1)
if targets.dim() == 1:
losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
else: # soft-labels
losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
losses = losses.sum(-1)
nll_loss = mean_ds(losses)
if label_smoothing > 0:
loss = (
nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
)
else:
loss = nll_loss
loss = loss * factor
return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
def _custom_loss(self, loss, name="loss", factor=1.0):
return {"name": name, "loss": loss, "factor": factor}
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
nsentences, ntokens = sample["nsentences"], sample["ntokens"]
# B x T
src_tokens, src_lengths = (
sample["net_input"]["src_tokens"],
sample["net_input"]["src_lengths"],
)
tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
losses, nll_loss = [], []
for obj in outputs:
if outputs[obj].get("loss", None) is None:
_losses = self._compute_loss(
outputs[obj].get("out"),
outputs[obj].get("tgt"),
outputs[obj].get("mask", None),
outputs[obj].get("ls", 0.0),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
else:
_losses = self._custom_loss(
outputs[obj].get("loss"),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
losses += [_losses]
if outputs[obj].get("nll_loss", False):
nll_loss += [_losses.get("nll_loss", 0.0)]
loss = sum(l["loss"] for l in losses)
nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)
# NOTE:
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
for l in losses:
logging_output[l["name"]] = (
utils.item(l["loss"].data / l["factor"])
if reduce
else l[["loss"]].data / l["factor"]
)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))
metrics.log_scalar(
"loss", loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
for key in logging_outputs[0]:
if key[-5:] == "-loss":
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(
key[:-5],
val / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,141 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class SentencePredictionConfig(FairseqDataclass):
classification_head_name: str = field(
default="sentence_classification_head",
metadata={"help": "name of the classification head to use"},
)
regression_target: bool = field(
default=False,
)
@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionCriterion(FairseqCriterion):
def __init__(self, cfg: SentencePredictionConfig, task):
super().__init__(task)
self.classification_head_name = cfg.classification_head_name
self.regression_target = cfg.regression_target
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
task_loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
task_loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {}
loss = task_loss
# mha & ffn regularization update
if (
hasattr(model.args, "mha_reg_scale_factor")
and model.args.mha_reg_scale_factor != 0.0
):
mha_reg_loss = model._get_adaptive_head_loss()
loss += mha_reg_loss
logging_output.update({"mha_reg_loss": mha_reg_loss})
if (
hasattr(model.args, "ffn_reg_scale_factor")
and model.args.ffn_reg_scale_factor != 0.0
):
ffn_reg_loss = model._get_adaptive_ffn_loss()
loss += ffn_reg_loss
logging_output.update({"ffn_reg_loss": ffn_reg_loss})
logging_output.update(
{
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
)
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs)
ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if mha_reg_loss_sum:
metrics.log_scalar(
"mha_reg_loss",
mha_reg_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ffn_reg_loss_sum:
metrics.log_scalar(
"ffn_reg_loss",
ffn_reg_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq.criterions import register_criterion
from fairseq.criterions.sentence_prediction import (
SentencePredictionCriterion,
SentencePredictionConfig,
)
@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
if not hasattr(sample, "lang_id"):
# If no language ID is given, we fall back to English
lang_id = ["en_XX"] * sample["nsentences"]
else:
lang_id = sample["lang_id"]
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
lang_id=lang_id,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output

View File

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
self.prediction_h = open(save_predictions, "w")
else:
self.prediction_h = None
self.num_classes = num_classes
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
parser.add_argument('--ranking-head-name',
default='sentence_classification_head',
help='name of the ranking head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.ranking_head_name in model.classification_heads
), "model must provide sentence ranking head for --criterion=sentence_ranking"
scores = []
for idx in range(self.num_classes):
score, _ = model(
**sample["net_input{idx}".format(idx=idx + 1)],
classification_head_name=self.ranking_head_name,
)
scores.append(score)
logits = torch.cat(scores, dim=1)
sample_size = logits.size(0)
if "target" in sample:
targets = model.get_targets(sample, [logits]).view(-1)
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
targets = None
loss = torch.tensor(0.0, requires_grad=True)
if self.prediction_h is not None:
preds = logits.argmax(dim=1)
for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
if targets is not None:
label = targets[i].item()
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
else:
print("{}\t{}".format(id, pred), file=self.prediction_h)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if targets is not None:
logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True

View File

@@ -0,0 +1,516 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from collections import OrderedDict
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.ctc import CtcCriterion
from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import (
RdropLabelSmoothedCrossEntropyCriterion,
RdropLabelSmoothedCrossEntropyCriterionConfig,
duplicate_input,
)
from fairseq.criterions.tacotron2_loss import (
Tacotron2Criterion,
Tacotron2CriterionConfig,
)
logger = logging.getLogger(__name__)
class MultitaskCriterion:
def __init__(self, multitask_tasks, rdrop_alpha=0.0):
self.rdrop_alpha = rdrop_alpha
self.rdrop_alpha_mtl = rdrop_alpha
self.multitask_criterion = OrderedDict()
self.multitask_loss_weight = OrderedDict()
for task_name, task_obj in multitask_tasks.items():
if task_obj.args.get_loss_weight(0) == 0:
logger.info(f"Skip {task_name} loss criterion")
continue
rdrop_alpha_task = task_obj.args.rdrop_alpha
if rdrop_alpha_task is None:
rdrop_alpha_task = rdrop_alpha
self.rdrop_alpha_mtl = rdrop_alpha_task
logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}")
if task_obj.args.decoder_type == "ctc":
self.multitask_criterion[task_name] = CtcCriterion(
task_obj.args.criterion_cfg,
task_obj,
rdrop_alpha=rdrop_alpha_task,
)
else:
self.multitask_criterion[
task_name
] = RdropLabelSmoothedCrossEntropyCriterion(
task_obj,
task_obj.args.criterion_cfg.sentence_avg,
label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
rdrop_alpha=rdrop_alpha_task,
)
def set_multitask_loss_weight(self, task_name, weight=0.0):
self.multitask_loss_weight[task_name] = weight
def get_multitask_loss(self, model, sample, model_out):
logging_output = {}
loss = 0.0
for task_name, task_criterion in self.multitask_criterion.items():
layer_id = task_criterion.task.args.input_layer
if isinstance(task_criterion, CtcCriterion):
if task_criterion.task.args.input_from == "encoder":
if len(model_out["encoder_padding_mask"]) > 0:
non_padding_mask = ~model_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
else:
out = model_out["encoder_states"][layer_id]
input_lengths = out.new_full(
(out.shape[1],), out.shape[0]
).long()
task_sample = {
"net_input": {
"src_tokens": model_out["encoder_states"][
layer_id
], # check batch idx
"src_lengths": input_lengths,
},
"id": sample["id"],
}
else:
task_sample = {
"net_input": {
"src_tokens": model_out["inner_states"][layer_id],
"src_lengths": sample["target_lengths"],
},
"id": sample["id"],
}
else:
task_sample = {
"net_input": {
"src_tokens": sample["multitask"][task_name]["net_input"][
"prev_output_tokens"
],
"encoder_out": {
"encoder_out": [model_out["encoder_states"][layer_id]],
"encoder_padding_mask": model_out["encoder_padding_mask"],
},
}
}
for key in ["target", "target_lengths", "ntokens"]:
task_sample[key] = sample["multitask"][task_name][key]
if task_name == getattr(model, "mt_task_name", None):
decoder_out = model_out["mt_decoder_out"]
else:
decoder_out = None
task_loss, task_sample_size, task_logging_output = task_criterion(
model.multitask_decoders[task_name], task_sample, net_output=decoder_out
)
loss = loss + self.multitask_loss_weight[task_name] * task_loss
task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
logging_output[task_name] = task_logging_output
return loss, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
for task_name in logging_outputs[0]["multitask"].keys():
# different criterion may return different logging
# currently only reduce on loss, the most common one
# ideally the way that losses are reduced should also depend on the task type
loss_sum = sum(
log["multitask"][task_name].get("loss", 0) for log in logging_outputs
)
sample_size = sum(
log["multitask"][task_name].get("sample_size", 0)
for log in logging_outputs
)
metrics.log_scalar(
f"multitask_{task_name}_loss",
loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
loss_weight = logging_outputs[0]["multitask"][task_name].get(
"loss_weight", 0
)
metrics.log_scalar(
f"multitask_{task_name}_loss_weight",
loss_weight,
weight=0,
priority=250,
)
@register_criterion(
"speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnitMultitaskTaskCriterion(
RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion
):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
rdrop_alpha,
)
MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha)
def forward(self, model, sample, reduce=True):
net_input_concat = {
"src_tokens": sample["net_input"]["src_tokens"],
"src_lengths": sample["net_input"]["src_lengths"],
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
"return_all_hiddens": True,
}
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
net_input_concat = duplicate_input(net_input_concat)
net_output, extra = model(**net_input_concat)
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, [net_output], sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, [net_output], sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
# inference metrics
if "targ_frames" in logging_outputs[0]:
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
if "multitask" not in logging_outputs[0]:
return
MultitaskCriterion.reduce_metrics(logging_outputs)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
@register_criterion(
"speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
rdrop_alpha,
)
def forward(self, model, sample, reduce=True):
net_input_concat = {
"src_tokens": sample["net_input"]["src_tokens"],
"src_lengths": sample["net_input"]["src_lengths"],
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
"prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
"net_input"
]["prev_output_tokens"],
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
"return_all_hiddens": True,
}
if getattr(model, "asr_task_name", None) is not None:
net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
model.asr_task_name
]["net_input"]["prev_output_tokens"]
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
net_input_concat = duplicate_input(net_input_concat)
net_output, extra = model(**net_input_concat)
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, [net_output], sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, [net_output], sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
)
MultitaskCriterion.__init__(self, task.multitask_tasks)
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
feat_out, eos_out, extra = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
tgt_speaker=sample["net_input"]["tgt_speaker"],
target_lengths=sample["target_lengths"],
return_all_hiddens=True,
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
sample["target_lengths"],
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(
extra["attn"],
sample["net_input"]["src_lengths"],
sample["target_lengths"],
reduction,
)
loss = (
l1_loss + mse_loss + eos_loss + attn_loss
) # do not include ctc loss as there's no text target
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
}
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
# inference metrics
if "targ_frames" in logging_outputs[0]:
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
if "multitask" not in logging_outputs[0]:
return
MultitaskCriterion.reduce_metrics(logging_outputs)
@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogram2passMultitaskTaskCriterion(
SpeechToSpectrogramMultitaskTaskCriterion
):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
)
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
feat_out, eos_out, extra = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
"prev_output_tokens"
],
tgt_speaker=sample["net_input"]["tgt_speaker"],
target_lengths=sample["target_lengths"],
return_all_hiddens=True,
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
sample["target_lengths"],
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(
extra["attn"],
sample["net_input"]["src_lengths"],
sample["target_lengths"],
reduction,
)
loss = (
l1_loss + mse_loss + eos_loss + attn_loss
) # do not include ctc loss as there's no text target
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
}
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output

View File

@@ -0,0 +1,126 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from dataclasses import dataclass, field
import torch.nn.functional as F
from fairseq import metrics
from fairseq.tasks import FairseqTask
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class SpeechUnitLmCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
loss_weights: str = field(
default="1.;0.0;0.0",
metadata={
"help": "Weights of the losses that correspond to token, duration, and F0 streams"
},
)
discrete_duration: bool = II("task.discrete_duration")
discrete_f0: bool = II("task.discrete_f0")
def mae_loss(pred, targ, mask, reduce=True):
if pred.ndim == 3:
pred = pred.squeeze(2)
else:
assert pred.ndim == 2
loss = (pred.float() - targ.float()).abs() * (~mask).float()
loss = loss.sum() if reduce else loss.view(-1)
return loss
def nll_loss(pred, targ, mask, reduce=True):
lprob = F.log_softmax(pred, dim=-1)
loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none")
loss = loss * (~mask).float().view(-1)
loss = loss.sum() if reduce else loss.view(-1)
return loss
@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig)
class SpeechUnitLmCriterion(FairseqCriterion):
def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask):
super().__init__(task)
self.sentence_avg = cfg.sentence_avg
self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")])
assert self.weights.size(0) == 3
assert (self.weights >= 0.0).all()
self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss
self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
token_loss = nll_loss(
net_output["token"], sample["target"], sample["mask"], reduce
)
dur_loss = self.dur_loss_fn(
net_output["duration"],
sample["dur_target"],
sample["dur_mask"],
reduce,
)
f0_loss = self.f0_loss_fn(
net_output["f0"],
sample["f0_target"],
sample["f0_mask"],
reduce,
)
loss = self.weights.to(token_loss.device) * torch.stack(
[token_loss, dur_loss, f0_loss], dim=-1
)
loss = loss.sum() if reduce else loss.sum(-1)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.detach().sum().item(),
"token_loss": token_loss.detach().sum().item(),
"dur_loss": dur_loss.detach().sum().item(),
"f0_loss": f0_loss.detach().sum().item(),
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs)
dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs)
f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar(
"token_loss", token_loss_sum / sample_size, sample_size, round=3
)
metrics.log_scalar("dur_loss", dur_loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar("f0_loss", f0_loss_sum / sample_size, sample_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return True

View File

@@ -0,0 +1,226 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, Dict, List
import torch
import torch.nn.functional as F
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import lengths_to_mask
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
bce_pos_weight: float = field(
default=1.0,
metadata={"help": "weight of positive examples for BCE loss"},
)
use_guided_attention_loss: bool = field(
default=False,
metadata={"help": "use guided attention loss"},
)
guided_attention_loss_sigma: float = field(
default=0.4,
metadata={"help": "weight of positive examples for BCE loss"},
)
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
sentence_avg: bool = II("optimization.sentence_avg")
class GuidedAttentionLoss(torch.nn.Module):
"""
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
"""
def __init__(self, sigma):
super().__init__()
self.sigma = sigma
@staticmethod
@lru_cache(maxsize=8)
def _get_weight(s_len, t_len, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
grid_x = grid_x.to(s_len.device)
grid_y = grid_y.to(s_len.device)
w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
return 1.0 - torch.exp(-w / (2 * (sigma**2)))
def _get_weights(self, src_lens, tgt_lens):
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
weights = torch.zeros((bsz, max_t_len, max_s_len))
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
return weights
@staticmethod
def _get_masks(src_lens, tgt_lens):
in_masks = lengths_to_mask(src_lens)
out_masks = lengths_to_mask(tgt_lens)
return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)
def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
loss = (weights * attn.transpose(1, 2)).masked_select(masks)
loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
return loss
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.bce_pos_weight = bce_pos_weight
self.guided_attn = None
if use_guided_attention_loss:
self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
self.ctc_weight = ctc_weight
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
feat_out, eos_out, extra = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
net_output = (feat_out, eos_out, extra)
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
def compute_loss(
self,
feat_out,
feat_out_post,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction="mean",
):
mask = lengths_to_mask(tgt_lens)
_eos_out = eos_out[mask].squeeze()
_eos_tgt = eos_tgt[mask]
_feat_tgt = feat_tgt[mask]
_feat_out = feat_out[mask]
_feat_out_post = feat_out_post[mask]
l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
eos_loss = F.binary_cross_entropy_with_logits(
_eos_out,
_eos_tgt,
pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction,
)
return l1_loss, mse_loss, eos_loss
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False

View File

@@ -0,0 +1,230 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor
@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
infonce: bool = field(
default=False,
metadata={
"help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
self.xla = is_xla_tensor(logits)
# XXX: handle weights on xla.
weights = None
if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
losses = []
reduction = "none" if ((not reduce) or self.xla) else "sum"
if self.infonce:
loss = F.cross_entropy(logits, target, reduction=reduction)
else:
loss = F.binary_cross_entropy_with_logits(
logits, target.float(), weights, reduction=reduction
)
if self.xla:
# tpu-comment: since dynamic shapes lead to recompilations on xla,
# we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss.
mi = (
sample["net_input"]["mask_indices"]
.transpose(0, 1) # logits are transposed in `model.get_logits`
.reshape(logits.size(0))
)
loss = (loss * mi).sum() if reduce else (loss * mi)
if "sample_size" in sample:
sample_size = sample["sample_size"]
elif "mask_indices" in sample["net_input"]:
sample_size = sample["net_input"]["mask_indices"].sum()
else:
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
losses.append(p)
logging_output = {
"loss": loss.item() if (reduce and not self.xla) else loss.detach(),
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
for lk in self.log_keys:
# Only store "logits" and "target" for computing MAP and MAUC
# during validation
if lk == "logits":
if not self.training:
logging_output["logits"] = logits.cpu().numpy()
elif lk == "target":
if not self.training:
# If the targets have been mixed with the predictions of
# teacher models, find the original targets
if hasattr(model, "get_original_targets"):
original_target = model.get_original_targets(sample, net_output)
else:
original_target = target
logging_output["target"] = original_target.cpu().numpy()
elif lk in net_output:
value = net_output[lk]
if not is_xla_tensor(value):
value = float(value)
logging_output[lk] = value
if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()
if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
corr = 0
count = 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
if is_xla_tensor(logits):
max, min = max * mi, min * mi
both = max & min
corr = max.long().sum() - both.long().sum()
count = mi.sum()
else:
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = float(max.numel())
logging_output["correct"] = corr
logging_output["count"] = count
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(
meters["_correct"].sum / meters["_total"].sum, 5
)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs)
if k.startswith("loss"):
metrics.log_scalar(
k, val / (sample_size or 1) / math.log(2), sample_size, round=3
)
else:
metrics.log_scalar(k, val / len(logging_outputs), round=3)
# FIXME: revert when gather based xla reduction is implemented
# @staticmethod
# def logging_outputs_can_be_summed() -> bool:
def logging_outputs_can_be_summed(self) -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
# XXX: Gather based reduction not implemented for xla yet.
# So we fall to sum based reduction for xla.
return self.xla

View File

@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
"AddTargetDataset",
"AppendTokenDataset",
"BacktranslationDataset",
"BaseWrapperDataset",
"BinarizedAudioDataset",
"BucketPadLengthDataset",
"ColorizeDataset",
"ConcatDataset",
"ConcatSentencesDataset",
"CountingIterator",
"DenoisingDataset",
"Dictionary",
"EncodedFastaDataset",
"EpochBatchIterator",
"FairseqDataset",
"FairseqIterableDataset",
"FastaDataset",
"FileAudioDataset",
"GroupedIterator",
"HubertDataset",
"IdDataset",
"IndexedCachedDataset",
"IndexedDataset",
"IndexedRawTextDataset",
"LanguagePairDataset",
"LeftPadDataset",
"ListDataset",
"LMContextWindowDataset",
"LRUCacheDataset",
"MaskTokensDataset",
"MMapIndexedDataset",
"MonolingualDataset",
"MultiCorpusSampledDataset",
"NestedDictionaryDataset",
"NoisingDataset",
"NumelDataset",
"NumSamplesDataset",
"OffsetTokensDataset",
"PadDataset",
"PrependDataset",
"PrependTokenDataset",
"RandomCropDataset",
"RawLabelDataset",
"ResamplingDataset",
"ReplaceDataset",
"RightPadDataset",
"RollDataset",
"RoundRobinZipDatasets",
"SampledMultiDataset",
"SampledMultiEpochDataset",
"ShardedIterator",
"SortDataset",
"StripTokenDataset",
"SubsampleDataset",
"TokenBlockDataset",
"TransformEosDataset",
"TransformEosLangPairDataset",
"TransformEosConcatLangPairDataset",
"TruncateDataset",
"TruncatedDictionary",
]

View File

@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
class AddTargetDataset(BaseWrapperDataset):
def __init__(
self,
dataset,
labels,
pad,
eos,
batch_targets,
process_label=None,
label_len_fn=None,
add_to_input=False,
text_compression_level=TextCompressionLevel.none,
):
super().__init__(dataset)
self.labels = labels
self.batch_targets = batch_targets
self.pad = pad
self.eos = eos
self.process_label = process_label
self.label_len_fn = label_len_fn
self.add_to_input = add_to_input
self.text_compressor = TextCompressor(level=text_compression_level)
def get_label(self, index, process_fn=None):
lbl = self.labels[index]
lbl = self.text_compressor.decompress(lbl)
return lbl if process_fn is None else process_fn(lbl)
def __getitem__(self, index):
item = self.dataset[index]
item["label"] = self.get_label(index, process_fn=self.process_label)
return item
def size(self, index):
sz = self.dataset.size(index)
own_sz = self.label_len_fn(self.get_label(index))
return sz, own_sz
def collater(self, samples):
collated = self.dataset.collater(samples)
if len(collated) == 0:
return collated
indices = set(collated["id"].tolist())
target = [s["label"] for s in samples if s["id"] in indices]
if self.add_to_input:
eos = torch.LongTensor([self.eos])
prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target]
target = [torch.cat([t, eos], axis=-1) for t in target]
collated["net_input"]["prev_output_tokens"] = prev_output_tokens
if self.batch_targets:
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
collated["ntokens"] = collated["target_lengths"].sum().item()
if getattr(collated["net_input"], "prev_output_tokens", None):
collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
collated["net_input"]["prev_output_tokens"],
pad_idx=self.pad,
left_pad=False,
)
else:
collated["ntokens"] = sum([len(t) for t in target])
collated["target"] = target
return collated
def filter_indices_by_size(self, indices, max_sizes):
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
return indices, ignored

View File

@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class AppendTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
if token is not None:
self._sizes = np.array(dataset.sizes) + 1
else:
self._sizes = dataset.sizes
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item, item.new([self.token])])
return item
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
n = self.dataset.num_tokens(index)
if self.token is not None:
n += 1
return n
def size(self, index):
n = self.dataset.size(index)
if self.token is not None:
n += 1
return n

View File

@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
import importlib
import os
import numpy as np
class AudioTransform(ABC):
@classmethod
@abstractmethod
def from_config_dict(cls, config: Optional[Dict] = None):
pass
class CompositeAudioTransform(AudioTransform):
def _from_config_dict(
cls,
transform_type,
get_audio_transform,
composite_cls,
config=None,
return_empty=False,
):
_config = {} if config is None else config
_transforms = _config.get(f"{transform_type}_transforms")
if _transforms is None:
if return_empty:
_transforms = []
else:
return None
transforms = [
get_audio_transform(_t).from_config_dict(_config.get(_t))
for _t in _transforms
]
return composite_cls(transforms)
def __init__(self, transforms):
self.transforms = [t for t in transforms if t is not None]
def __call__(self, x):
for t in self.transforms:
x = t(x)
return x
def __repr__(self):
format_string = (
[self.__class__.__name__ + "("]
+ [f" {t.__repr__()}" for t in self.transforms]
+ [")"]
)
return "\n".join(format_string)
def register_audio_transform(name, cls_type, registry, class_names):
def register_audio_transform_cls(cls):
if name in registry:
raise ValueError(f"Cannot register duplicate transform ({name})")
if not issubclass(cls, cls_type):
raise ValueError(
f"Transform ({name}: {cls.__name__}) must extend "
f"{cls_type.__name__}"
)
if cls.__name__ in class_names:
raise ValueError(
f"Cannot register audio transform with duplicate "
f"class name ({cls.__name__})"
)
registry[name] = cls
class_names.add(cls.__name__)
return cls
return register_audio_transform_cls
def import_transforms(transforms_dir, transform_type):
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module(
f"fairseq.data.audio.{transform_type}_transforms." + name
)
# Utility fn for uniform numbers in transforms
def rand_uniform(a, b):
return np.random.uniform() * (b - a) + a

View File

@@ -0,0 +1,389 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
from pathlib import Path
import io
from typing import BinaryIO, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def convert_waveform(
waveform: Union[np.ndarray, torch.Tensor],
sample_rate: int,
normalize_volume: bool = False,
to_mono: bool = False,
to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
"""convert a waveform:
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
try:
import torchaudio.sox_effects as ta_sox
except ImportError:
raise ImportError("Please install torchaudio: pip install torchaudio")
effects = []
if normalize_volume:
effects.append(["gain", "-n"])
if to_sample_rate is not None and to_sample_rate != sample_rate:
effects.append(["rate", f"{to_sample_rate}"])
if to_mono and waveform.shape[0] > 1:
effects.append(["channels", "1"])
if len(effects) > 0:
is_np_input = isinstance(waveform, np.ndarray)
_waveform = torch.from_numpy(waveform) if is_np_input else waveform
converted, converted_sample_rate = ta_sox.apply_effects_tensor(
_waveform, sample_rate, effects
)
if is_np_input:
converted = converted.numpy()
return converted, converted_sample_rate
return waveform, sample_rate
def get_waveform(
path_or_fp: Union[str, BinaryIO],
normalization: bool = True,
mono: bool = True,
frames: int = -1,
start: int = 0,
always_2d: bool = True,
output_sample_rate: Optional[int] = None,
normalize_volume: bool = False,
waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): normalize values to [-1, 1] (Default: True)
mono (bool): convert multi-channel audio to mono-channel one
frames (int): the number of frames to read. (-1 for reading all)
start (int): Where to start reading. A negative value counts from the end.
always_2d (bool): always return 2D array even for mono-channel audios
output_sample_rate (Optional[int]): output sample rate
normalize_volume (bool): normalize volume
Returns:
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
sample_rate (float): sample rate
"""
if isinstance(path_or_fp, str):
ext = Path(path_or_fp).suffix
if ext not in SF_AUDIO_FILE_EXTENSIONS:
raise ValueError(f"Unsupported audio format: {ext}")
try:
import soundfile as sf
except ImportError:
raise ImportError("Please install soundfile: pip install soundfile")
waveform, sample_rate = sf.read(
path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
)
waveform = waveform.T # T x C -> C x T
waveform, sample_rate = convert_waveform(
waveform,
sample_rate,
normalize_volume=normalize_volume,
to_mono=mono,
to_sample_rate=output_sample_rate,
)
if not normalization:
waveform *= 2**15 # denormalized to 16-bit signed integers
if waveform_transforms is not None:
waveform, sample_rate = waveform_transforms(waveform, sample_rate)
if not always_2d:
waveform = waveform.squeeze(axis=0)
return waveform, sample_rate
def get_features_from_npy_or_audio(path, waveform_transforms=None):
ext = Path(path).suffix
if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
raise ValueError(f'Unsupported file format for "{path}"')
return (
np.load(path)
if ext == ".npy"
else get_fbank(path, waveform_transforms=waveform_transforms)
)
def get_features_or_waveform_from_stored_zip(
path,
byte_offset,
byte_size,
need_waveform=False,
use_sample_rate=None,
waveform_transforms=None,
):
assert path.endswith(".zip")
data = read_from_stored_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
features_or_waveform = np.load(f)
elif is_sf_audio_data(data):
features_or_waveform = (
get_waveform(
f,
always_2d=False,
output_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)[0]
if need_waveform
else get_fbank(f, waveform_transforms=waveform_transforms)
)
else:
raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform
def get_features_or_waveform(
path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
):
"""Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length.
Args:
path (str): File path in the format of "<.npy/.wav/.flac path>" or
"<zip path>:<byte offset>:<byte length>".
need_waveform (bool): return waveform instead of features.
use_sample_rate (int): change sample rate for the input wave file
Returns:
features_or_waveform (numpy.ndarray): speech features or waveform.
"""
_path, slice_ptr = parse_path(path)
if len(slice_ptr) == 0:
if need_waveform:
return get_waveform(
_path,
always_2d=False,
output_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)[0]
return get_features_from_npy_or_audio(
_path, waveform_transforms=waveform_transforms
)
elif len(slice_ptr) == 2:
features_or_waveform = get_features_or_waveform_from_stored_zip(
_path,
slice_ptr[0],
slice_ptr[1],
need_waveform=need_waveform,
use_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)
else:
raise ValueError(f"Invalid path: {path}")
return features_or_waveform
def _get_kaldi_fbank(
waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.fbank import Fbank, FbankOptions
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(
waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform)
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(
path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
) -> np.ndarray:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
waveform, sample_rate = get_waveform(
path_or_fp, normalization=False, waveform_transforms=waveform_transforms
)
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
if features is None:
features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return features
def is_npy_data(data: bytes) -> bool:
return data[0] == 147 and data[1] == 78
def is_sf_audio_data(data: bytes) -> bool:
is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
return is_wav or is_flac or is_ogg
def mmap_read(path: str, offset: int, length: int) -> bytes:
with open(path, "rb") as f:
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
data = mmap_o[offset : offset + length]
return data
def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
return mmap_read(zip_path, offset, length)
def parse_path(path: str) -> Tuple[str, List[int]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
_path, slice_ptr = path, []
else:
_path, *slice_ptr = path.split(":")
if not Path(_path).is_file():
raise FileNotFoundError(f"File not found: {_path}")
assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
slice_ptr = [int(i) for i in slice_ptr]
return _path, slice_ptr
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
padding = n_fft - win_length
assert padding >= 0
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
def get_fourier_basis(n_fft: int) -> torch.Tensor:
basis = np.fft.fft(np.eye(n_fft))
basis = np.vstack(
[np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
)
return torch.from_numpy(basis).float()
def get_mel_filters(
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
try:
import librosa
except ImportError:
raise ImportError("Please install librosa: pip install librosa")
basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
return torch.from_numpy(basis).float()
class TTSSpectrogram(torch.nn.Module):
def __init__(
self,
n_fft: int,
win_length: int,
hop_length: int,
window_fn: callable = torch.hann_window,
return_phase: bool = False,
) -> None:
super(TTSSpectrogram, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.return_phase = return_phase
basis = get_fourier_basis(n_fft).unsqueeze(1)
basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer("basis", basis)
def forward(
self, waveform: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
padding = (self.n_fft // 2, self.n_fft // 2)
x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
x = F.conv1d(x, self.basis, stride=self.hop_length)
real_part = x[:, : self.n_fft // 2 + 1, :]
imag_part = x[:, self.n_fft // 2 + 1 :, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
if self.return_phase:
phase = torch.atan2(imag_part, real_part)
return magnitude, phase
return magnitude
class TTSMelScale(torch.nn.Module):
def __init__(
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
) -> None:
super(TTSMelScale, self).__init__()
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
self.register_buffer("basis", basis)
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.basis, specgram)

View File

@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional
from fairseq.data import Dictionary
logger = logging.getLogger(__name__)
def get_config_from_yaml(yaml_path: Path):
try:
import yaml
except ImportError:
print("Please install PyYAML: pip install PyYAML")
config = {}
if yaml_path.is_file():
try:
with open(yaml_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
else:
raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
return config
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path: Path):
self.config = get_config_from_yaml(yaml_path)
self.root = yaml_path.parent
def _auto_convert_to_abs_path(self, x):
if isinstance(x, str):
if not Path(x).exists() and (self.root / x).exists():
return (self.root / x).as_posix()
elif isinstance(x, dict):
return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
return x
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", "dict.txt")
@property
def speaker_set_filename(self):
"""speaker set file under data root"""
return self.config.get("speaker_set_filename", None)
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_tgt_lang_tag", False)
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get("input_feat_per_channel", 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
return self.config.get("input_channels", 1)
@property
def sample_rate(self):
return self.config.get("sample_rate", 16_000)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return self.config.get("sampling_alpha", 1.0)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return self.config.get("use_audio_input", False)
def standardize_audio(self) -> bool:
return self.use_audio_input and self.config.get("standardize_audio", False)
@property
def use_sample_rate(self):
"""Needed by the dataset loader to see if the model requires
raw audio with specific sample rate as inputs."""
return self.config.get("use_sample_rate", 16000)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return self.config.get("audio_root", "")
def get_transforms(self, transform_type, split, is_train):
"""Split-specific feature transforms. Allowing train set
wildcard `_train`, evaluation set wildcard `_eval` and general
wildcard `*` for matching."""
from copy import deepcopy
cfg = deepcopy(self.config)
_cur = cfg.get(f"{transform_type}transforms", {})
cur = _cur.get(split)
cur = _cur.get("_train") if cur is None and is_train else cur
cur = _cur.get("_eval") if cur is None and not is_train else cur
cur = _cur.get("*") if cur is None else cur
return cur
def get_feature_transforms(self, split, is_train):
cfg = deepcopy(self.config)
# TODO: deprecate transforms
cur = self.get_transforms("", split, is_train)
if cur is not None:
logger.warning(
"Auto converting transforms into feature_transforms, "
"but transforms will be deprecated in the future. Please "
"update this in the config."
)
ft_transforms = self.get_transforms("feature_", split, is_train)
if ft_transforms:
cur.extend(ft_transforms)
else:
cur = self.get_transforms("feature_", split, is_train)
cfg["feature_transforms"] = cur
return cfg
def get_waveform_transforms(self, split, is_train):
cfg = deepcopy(self.config)
cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
return cfg
def get_dataset_transforms(self, split, is_train):
cfg = deepcopy(self.config)
cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
return cfg
@property
def global_cmvn_stats_npz(self) -> Optional[str]:
path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
return self._auto_convert_to_abs_path(path)
@property
def vocoder(self) -> Dict[str, str]:
vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
return self._auto_convert_to_abs_path(vocoder)
@property
def hub(self) -> Dict[str, str]:
return self.config.get("hub", {})
class S2SDataConfig(S2TDataConfig):
"""Wrapper class for data config YAML"""
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", None)
@property
def pre_tokenizer(self) -> Dict:
return None
@property
def bpe_tokenizer(self) -> Dict:
return None
@property
def input_transformed_channels(self):
"""The number of channels in the audio after feature transforms"""
# TODO: move this into individual transforms
# TODO: deprecate transforms
_cur = self.config.get("transforms", {})
ft_transforms = self.config.get("feature_transforms", {})
if _cur and ft_transforms:
_cur.update(ft_transforms)
else:
_cur = self.config.get("feature_transforms", {})
cur = _cur.get("_train", [])
_channels = self.input_channels
if "delta_deltas" in cur:
_channels *= 3
return _channels
@property
def output_sample_rate(self):
"""The audio sample rate of output target speech"""
return self.config.get("output_sample_rate", 22050)
@property
def target_speaker_embed(self):
"""Target speaker embedding file (one line per target audio sample)"""
return self.config.get("target_speaker_embed", None)
@property
def prepend_tgt_lang_tag_as_bos(self) -> bool:
"""Prepend target lang ID token as the target BOS."""
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
class MultitaskConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path: Path):
config = get_config_from_yaml(yaml_path)
self.config = {}
for k, v in config.items():
self.config[k] = SingleTaskConfig(k, v)
def get_all_tasks(self):
return self.config
def get_single_task(self, name):
assert name in self.config, f"multitask '{name}' does not exist!"
return self.config[name]
@property
def first_pass_decoder_task_index(self):
"""Return the task index of the first-pass text decoder.
If there are multiple 'is_first_pass_decoder: True' in the config file,
the last task is used for the first-pass decoder.
If there is no 'is_first_pass_decoder: True' in the config file,
the last task whose task_name includes 'target' and decoder_type is not ctc.
"""
idx = -1
for i, (k, v) in enumerate(self.config.items()):
if v.is_first_pass_decoder:
idx = i
if idx < 0:
for i, (k, v) in enumerate(self.config.items()):
if k.startswith("target") and v.decoder_type == "transformer":
idx = i
return idx
class SingleTaskConfig(object):
def __init__(self, name, config):
self.task_name = name
self.config = config
dict_path = config.get("dict", "")
self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
@property
def data(self):
return self.config.get("data", "")
@property
def decoder_type(self):
return self.config.get("decoder_type", "transformer")
@property
def decoder_args(self):
"""Decoder arch related args"""
args = self.config.get("decoder_args", {})
return Namespace(**args)
@property
def criterion_cfg(self):
"""cfg for the multitask criterion"""
if self.decoder_type == "ctc":
from fairseq.criterions.ctc import CtcCriterionConfig
cfg = CtcCriterionConfig
cfg.zero_infinity = self.config.get("zero_infinity", True)
else:
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterionConfig,
)
cfg = LabelSmoothedCrossEntropyCriterionConfig
cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
return cfg
@property
def input_from(self):
"""Condition on encoder/decoder of the main model"""
return "decoder" if "decoder_layer" in self.config else "encoder"
@property
def input_layer(self):
if self.input_from == "decoder":
return self.config["decoder_layer"] - 1
else:
# default using the output from the last encoder layer (-1)
return self.config.get("encoder_layer", 0) - 1
@property
def loss_weight_schedule(self):
return (
"decay"
if "loss_weight_max" in self.config
and "loss_weight_decay_steps" in self.config
else "fixed"
)
def get_loss_weight(self, num_updates):
if self.loss_weight_schedule == "fixed":
weight = self.config.get("loss_weight", 1.0)
else: # "decay"
assert (
self.config.get("loss_weight_decay_steps", 0) > 0
), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
loss_weight_min = self.config.get("loss_weight_min", 0.0001)
loss_weight_decay_stepsize = (
self.config["loss_weight_max"] - loss_weight_min
) / self.config["loss_weight_decay_steps"]
weight = max(
self.config["loss_weight_max"]
- loss_weight_decay_stepsize * num_updates,
loss_weight_min,
)
return weight
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def eos_token(self):
"""EOS token during generation"""
return self.config.get("eos_token", "<eos>")
@property
def rdrop_alpha(self):
return self.config.get("rdrop_alpha", 0.0)
@property
def is_first_pass_decoder(self):
flag = self.config.get("is_first_pass_decoder", False)
if flag:
if self.decoder_type == "ctc":
raise ValueError(
"First-pass decoder in the multi-decoder model must not be CTC."
)
if "target" not in self.task_name:
raise Warning(
'The name of the first-pass decoder does not include "target".'
)
return flag
@property
def get_lang_tag_mapping(self):
return self.config.get("lang_tag_mapping", {})

View File

@@ -0,0 +1,53 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioDatasetTransform(AudioTransform):
pass
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
def get_audio_dataset_transform(name):
return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
def register_audio_dataset_transform(name):
return register_audio_transform(
name,
AudioDatasetTransform,
AUDIO_DATASET_TRANSFORM_REGISTRY,
AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "dataset")
class CompositeAudioDatasetTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"dataset",
get_audio_dataset_transform,
CompositeAudioDatasetTransform,
config,
return_empty=True,
)
def get_transform(self, cls):
for t in self.transforms:
if isinstance(t, cls):
return t
return None
def has_transform(self, cls):
return self.get_transform(cls) is not None

View File

@@ -0,0 +1,61 @@
from typing import List
import numpy as np
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return ConcatAugment(
_config.get("rate", _DEFAULTS["rate"]),
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
_config.get("attempts", _DEFAULTS["attempts"]),
)
def __init__(
self,
rate=_DEFAULTS["rate"],
max_tokens=_DEFAULTS["max_tokens"],
attempts=_DEFAULTS["attempts"],
):
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"rate={self.rate}",
f"max_tokens={self.max_tokens}",
f"attempts={self.attempts}",
]
)
+ ")"
)
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
# skip conditions: application rate, max_tokens limit exceeded
if np.random.random() > self.rate:
return [index]
if self.max_tokens and n_frames[index] > self.max_tokens:
return [index]
# pick second sample to concatenate
for _ in range(self.attempts):
index2 = np.random.randint(0, n_samples)
if index2 != index and (
not self.max_tokens
or n_frames[index] + n_frames[index2] < self.max_tokens
):
return [index, index2]
return [index]

View File

@@ -0,0 +1,105 @@
import numpy as np
import torch
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
from fairseq.data.audio.waveform_transforms.noiseaugment import (
NoiseAugmentTransform,
)
_DEFAULTS = {
"rate": 0.25,
"mixing_noise_rate": 0.1,
"noise_path": "",
"noise_snr_min": -5,
"noise_snr_max": 5,
"utterance_snr_min": -5,
"utterance_snr_max": 5,
}
@register_audio_dataset_transform("noisyoverlapaugment")
class NoisyOverlapAugment(AudioDatasetTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return NoisyOverlapAugment(
_config.get("rate", _DEFAULTS["rate"]),
_config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
_config.get("noise_path", _DEFAULTS["noise_path"]),
_config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
_config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
_config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
_config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
)
def __init__(
self,
rate=_DEFAULTS["rate"],
mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
noise_path=_DEFAULTS["noise_path"],
noise_snr_min=_DEFAULTS["noise_snr_min"],
noise_snr_max=_DEFAULTS["noise_snr_max"],
utterance_snr_min=_DEFAULTS["utterance_snr_min"],
utterance_snr_max=_DEFAULTS["utterance_snr_max"],
):
self.rate = rate
self.mixing_noise_rate = mixing_noise_rate
self.noise_shaper = NoiseAugmentTransform(noise_path)
self.noise_snr_min = noise_snr_min
self.noise_snr_max = noise_snr_max
self.utterance_snr_min = utterance_snr_min
self.utterance_snr_max = utterance_snr_max
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"rate={self.rate}",
f"mixing_noise_rate={self.mixing_noise_rate}",
f"noise_snr_min={self.noise_snr_min}",
f"noise_snr_max={self.noise_snr_max}",
f"utterance_snr_min={self.utterance_snr_min}",
f"utterance_snr_max={self.utterance_snr_max}",
]
)
+ ")"
)
def __call__(self, sources):
for i, source in enumerate(sources):
if np.random.random() > self.rate:
continue
pri = source.numpy()
if np.random.random() > self.mixing_noise_rate:
sec = sources[np.random.randint(0, len(sources))].numpy()
snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
else:
sec = self.noise_shaper.pick_sample(source.shape)
snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
L1 = pri.shape[-1]
L2 = sec.shape[-1]
l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
s_source = np.random.randint(0, L1 - l)
s_sec = np.random.randint(0, L2 - l)
get_power = lambda x: np.mean(x**2)
if get_power(sec) == 0:
continue
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
pri[s_source : s_source + l] = np.add(
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
)
sources[i] = torch.from_numpy(pri).float()
return sources

View File

@@ -0,0 +1,43 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioFeatureTransform(AudioTransform):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def get_audio_feature_transform(name):
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
def register_audio_feature_transform(name):
return register_audio_transform(
name,
AudioFeatureTransform,
AUDIO_FEATURE_TRANSFORM_REGISTRY,
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "feature")
class CompositeAudioFeatureTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"feature",
get_audio_feature_transform,
CompositeAudioFeatureTransform,
config,
)

View File

@@ -0,0 +1,37 @@
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
"""Expand delta-deltas features from spectrum."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return DeltaDeltas(_config.get("win_length", 5))
def __init__(self, win_length=5):
self.win_length = win_length
def __repr__(self):
return self.__class__.__name__
def __call__(self, spectrogram):
from torchaudio.functional import compute_deltas
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
# spectrogram is T x F, while compute_deltas takes (…, F, T)
spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
delta = compute_deltas(spectrogram)
delta_delta = compute_deltas(delta)
out_feat = np.concatenate(
[spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
)
out_feat = np.transpose(out_feat)
return out_feat

View File

@@ -0,0 +1,29 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return GlobalCMVN(_config.get("stats_npz_path"))
def __init__(self, stats_npz_path):
self.stats_npz_path = stats_npz_path
stats = np.load(stats_npz_path)
self.mean, self.std = stats["mean"], stats["std"]
def __repr__(self):
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
def __call__(self, x):
x = np.subtract(x, self.mean)
x = np.divide(x, self.std)
return x

View File

@@ -0,0 +1,131 @@
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
_config.get("time_warp_W", 0),
_config.get("freq_mask_N", 0),
_config.get("freq_mask_F", 0),
_config.get("time_mask_N", 0),
_config.get("time_mask_T", 0),
_config.get("time_mask_p", 0.0),
_config.get("mask_value", None),
)
def __init__(
self,
time_warp_w: int = 0,
freq_mask_n: int = 0,
freq_mask_f: int = 0,
time_mask_n: int = 0,
time_mask_t: int = 0,
time_mask_p: float = 0.0,
mask_value: Optional[float] = 0.0,
):
# Sanity checks
assert mask_value is None or isinstance(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
assert freq_mask_f > 0, (
f"freq_mask_F ({freq_mask_f}) "
f"must be larger than 0 when doing freq masking."
)
if time_mask_n > 0:
assert time_mask_t > 0, (
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
f"doing time masking."
)
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
self.freq_mask_f = freq_mask_f
self.time_mask_n = time_mask_n
self.time_mask_t = time_mask_t
self.time_mask_p = time_mask_p
self.mask_value = mask_value
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"time_warp_w={self.time_warp_w}",
f"freq_mask_n={self.freq_mask_n}",
f"freq_mask_f={self.freq_mask_f}",
f"time_mask_n={self.time_mask_n}",
f"time_mask_t={self.time_mask_t}",
f"time_mask_p={self.time_mask_p}",
]
)
+ ")"
)
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
distorted = spectrogram.copy() # make a copy of input spectrogram.
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
mask_value = self.mask_value
if mask_value is None: # if no value was specified, use local mean.
mask_value = spectrogram.mean()
if num_frames == 0:
return spectrogram
if num_freqs < self.freq_mask_f:
return spectrogram
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
dsize=(num_freqs, num_frames - w0 - w),
interpolation=cv2.INTER_LINEAR,
)
distorted = np.concatenate((upper, lower), axis=0)
for _i in range(self.freq_mask_n):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
distorted[:, f0 : f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
)
if max_time_mask_t < 1:
return distorted
for _i in range(self.time_mask_n):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
distorted[t0 : t0 + t, :] = mask_value
return distorted

View File

@@ -0,0 +1,41 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
_config.get("norm_means", True),
_config.get("norm_vars", True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
return (
self.__class__.__name__
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
)
def __call__(self, x):
mean = x.mean(axis=0)
square_sums = (x**2).sum(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x

View File

@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs
import csv
import logging
import os.path as op
from typing import List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset,
TextToSpeechDatasetCreator,
)
logger = logging.getLogger(__name__)
class FrmTextToSpeechDataset(TextToSpeechDataset):
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
do_chunk=False,
chunk_bound=-1,
chunk_init=50,
chunk_incr=5,
add_eos=True,
dedup=True,
ref_fpu=-1,
):
# It assumes texts are encoded at a fixed frame-rate
super().__init__(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.do_chunk = do_chunk
self.chunk_bound = chunk_bound
self.chunk_init = chunk_init
self.chunk_incr = chunk_incr
self.add_eos = add_eos
self.dedup = dedup
self.ref_fpu = ref_fpu
self.chunk_size = -1
if do_chunk:
assert self.chunk_incr >= 0
assert self.pre_tokenizer is None
def __getitem__(self, index):
index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
if target[-1].item() == self.tgt_dict.eos_index:
target = target[:-1]
fpu = source.size(0) / target.size(0) # frame-per-unit
fps = self.n_frames_per_step
assert (
self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
), f"{fpu*fps} != {self.ref_fpu}"
# only chunk training split
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
size = len(text)
chunk_size = min(self.chunk_size, size)
chunk_start = np.random.randint(size - chunk_size + 1)
text = text[chunk_start : chunk_start + chunk_size]
target = torch.cat((lang, text), 0)
f_size = int(np.floor(chunk_size * fpu))
f_start = int(np.floor(chunk_start * fpu))
assert f_size > 0
source = source[f_start : f_start + f_size, :]
if self.dedup:
target = torch.unique_consecutive(target)
if self.add_eos:
eos_idx = self.tgt_dict.eos_index
target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
return index, source, target, speaker_id
def set_epoch(self, epoch):
if self.is_train_split and self.do_chunk:
old = self.chunk_size
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
if self.chunk_bound > 0:
self.chunk_size = min(self.chunk_size, self.chunk_bound)
logger.info(
(
f"{self.split}: setting chunk size "
f"from {old} to {self.chunk_size}"
)
)
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
# inherit for key names
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2TDataConfig,
split: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
n_frames_per_step: int,
speaker_to_id,
do_chunk: bool = False,
chunk_bound: int = -1,
chunk_init: int = 50,
chunk_incr: int = 5,
add_eos: bool = True,
dedup: bool = True,
ref_fpu: float = -1,
) -> FrmTextToSpeechDataset:
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
s = [dict(e) for e in reader]
assert len(s) > 0
ids = [ss[cls.KEY_ID] for ss in s]
audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
return FrmTextToSpeechDataset(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
do_chunk=do_chunk,
chunk_bound=chunk_bound,
chunk_init=chunk_init,
chunk_incr=chunk_incr,
add_eos=add_eos,
dedup=dedup,
ref_fpu=ref_fpu,
)

View File

@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
)
import io
logger = logging.getLogger(__name__)
def load_audio(manifest_path, max_keep, min_keep):
n_long, n_short = 0, 0
names, inds, sizes = [], [], []
with open(manifest_path) as f:
root = f.readline().strip()
for ind, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_keep is not None and sz < min_keep:
n_short += 1
elif max_keep is not None and sz > max_keep:
n_long += 1
else:
names.append(items[0])
inds.append(ind)
sizes.append(sz)
tot = ind + 1
logger.info(
(
f"max_keep={max_keep}, min_keep={min_keep}, "
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
)
)
return root, names, inds, tot, sizes
def load_label(label_path, inds, tot):
with open(label_path) as f:
labels = [line.rstrip() for line in f]
assert (
len(labels) == tot
), f"number of labels does not match ({len(labels)} != {tot})"
labels = [labels[i] for i in inds]
return labels
def load_label_offset(label_path, inds, tot):
with open(label_path) as f:
code_lengths = [len(line.encode("utf-8")) for line in f]
assert (
len(code_lengths) == tot
), f"number of labels does not match ({len(code_lengths)} != {tot})"
offsets = list(itertools.accumulate([0] + code_lengths))
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
return offsets
def verify_label_lengths(
audio_sizes,
audio_rate,
label_path,
label_rate,
inds,
tot,
tol=0.1, # tolerance in seconds
):
if label_rate < 0:
logger.info(f"{label_path} is sequence label. skipped")
return
with open(label_path) as f:
lengths = [len(line.rstrip().split()) for line in f]
assert len(lengths) == tot
lengths = [lengths[i] for i in inds]
num_invalid = 0
for i, ind in enumerate(inds):
dur_from_audio = audio_sizes[i] / audio_rate
dur_from_label = lengths[i] / label_rate
if abs(dur_from_audio - dur_from_label) > tol:
logger.warning(
(
f"audio and label duration differ too much "
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
f"in line {ind+1} of {label_path}. Check if `label_rate` "
f"is correctly set (currently {label_rate}). "
f"num. of samples = {audio_sizes[i]}; "
f"label length = {lengths[i]}"
)
)
num_invalid += 1
if num_invalid > 0:
logger.warning(
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
)
class HubertDataset(FairseqDataset):
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[str],
eos_list: List[str],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
):
self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
manifest_path, max_keep_sample_size, min_keep_sample_size
)
self.sample_rate = sample_rate
self.shuffle = shuffle
self.random_crop = random_crop
self.num_labels = len(label_paths)
self.pad_list = pad_list
self.eos_list = eos_list
self.label_processors = label_processors
self.single_target = single_target
self.label_rates = (
[label_rates for _ in range(len(label_paths))]
if isinstance(label_rates, float)
else label_rates
)
self.store_labels = store_labels
if store_labels:
self.label_list = [load_label(p, inds, tot) for p in label_paths]
else:
self.label_paths = label_paths
self.label_offsets_list = [
load_label_offset(p, inds, tot) for p in label_paths
]
assert label_processors is None or len(label_processors) == self.num_labels
for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths(
self.sizes, sample_rate, label_path, label_rate, inds, tot
)
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.pad_audio = pad_audio
self.normalize = normalize
logger.info(
f"pad_audio={pad_audio}, random_crop={random_crop}, "
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
)
def get_audio(self, index):
import soundfile as sf
wav_path = os.path.join(self.audio_root, self.audio_names[index])
_path, slice_ptr = parse_path(wav_path)
if len(slice_ptr) == 0:
wav, cur_sample_rate = sf.read(_path)
else:
assert _path.endswith(".zip")
data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
f = io.BytesIO(data)
wav, cur_sample_rate = sf.read(f)
wav = torch.from_numpy(wav).float()
wav = self.postprocess(wav, cur_sample_rate)
return wav
def get_label(self, index, label_idx):
if self.store_labels:
label = self.label_list[label_idx][index]
else:
with open(self.label_paths[label_idx]) as f:
offset_s, offset_e = self.label_offsets_list[label_idx][index]
f.seek(offset_s)
label = f.read(offset_e - offset_s)
if self.label_processors is not None:
label = self.label_processors[label_idx](label)
return label
def get_labels(self, index):
return [self.get_label(index, i) for i in range(self.num_labels)]
def __getitem__(self, index):
wav = self.get_audio(index)
labels = self.get_labels(index)
return {"id": index, "source": wav, "label_list": labels}
def __len__(self):
return len(self.sizes)
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater(self, samples):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size
)
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch
def collater_audio(self, audios, audio_size):
collated_audios = audios[0].new_zeros(len(audios), audio_size)
padding_mask = (
torch.BoolTensor(collated_audios.shape).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, audio in enumerate(audios):
diff = len(audio) - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios, padding_mask, audio_starts
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
targets_list, lengths_list, ntokens_list = [], [], []
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1.0:
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad
)
targets_list.append(targets)
lengths_list.append(lengths)
ntokens_list.append(ntokens)
return targets_list, lengths_list, ntokens_list
def num_tokens(self, index):
return self.size(index)
def size(self, index):
if self.pad_audio:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)[::-1]
def postprocess(self, wav, cur_sample_rate):
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav

View File

@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import math
from typing import List, Optional, NamedTuple
import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
ConcatDataset,
LanguagePairDataset,
FileAudioDataset,
data_utils,
)
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class ModalityDatasetItem(NamedTuple):
datasetname: str
dataset: any
max_positions: List[int]
max_tokens: Optional[int] = None
max_sentences: Optional[int] = None
def resampling_dataset_present(ds):
if isinstance(ds, ResamplingDataset):
return True
if isinstance(ds, ConcatDataset):
return any(resampling_dataset_present(d) for d in ds.datasets)
if hasattr(ds, "dataset"):
return resampling_dataset_present(ds.dataset)
return False
# MultiModalityDataset: it concate multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds mode to indicate what type of the data samples come from.
# It will be used with GroupedEpochBatchIterator together to generate mini-batch with samples
# from the same type of dataset
# If only one dataset is used, it will perform like the original dataset with mode added
class MultiModalityDataset(ConcatDataset):
def __init__(self, datasets: List[ModalityDatasetItem]):
id_to_mode = []
dsets = []
max_tokens = []
max_sentences = []
max_positions = []
for dset in datasets:
id_to_mode.append(dset.datasetname)
dsets.append(dset.dataset)
max_tokens.append(dset.max_tokens)
max_positions.append(dset.max_positions)
max_sentences.append(dset.max_sentences)
weights = [1.0 for s in dsets]
super().__init__(dsets, weights)
self.max_tokens = max_tokens
self.max_positions = max_positions
self.max_sentences = max_sentences
self.id_to_mode = id_to_mode
self.raw_sub_batch_samplers = []
self._cur_epoch = 0
def set_epoch(self, epoch):
super().set_epoch(epoch)
self._cur_epoch = epoch
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
sample = self.datasets[dataset_idx][sample_idx]
return (dataset_idx, sample)
def collater(self, samples):
if len(samples) == 0:
return {}
dataset_idx = samples[0][0]
# make sure all samples in samples are from same dataset
assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
# add mode
samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
return samples
def size(self, index: int):
if len(self.datasets) == 1:
return self.datasets[0].size(index)
return super().size(index)
@property
def sizes(self):
if len(self.datasets) == 1:
return self.datasets[0].sizes
return super().sizes
def ordered_indices(self):
"""
Returns indices sorted by length. So less padding is needed.
"""
if len(self.datasets) == 1:
return self.datasets[0].ordered_indices()
indices_group = []
for d_idx, ds in enumerate(self.datasets):
sample_num = self.cumulative_sizes[d_idx]
if d_idx > 0:
sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
assert sample_num == len(ds)
indices_group.append(ds.ordered_indices())
return indices_group
def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
with data_utils.numpy_seed(seed):
indices = self.ordered_indices()
for i, ds in enumerate(self.datasets):
# If we have ResamplingDataset, the same id can correpond to a different
# sample in the next epoch, so we need to rebuild this at every epoch
if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
ds
):
logger.info(f"dataset {i} is valid and it is not re-sampled")
continue
indices[i] = ds.filter_indices_by_size(
indices[i],
self.max_positions[i],
)[0]
sub_batch_sampler = ds.batch_by_size(
indices[i],
max_tokens=self.max_tokens[i],
max_sentences=self.max_sentences[i],
required_batch_size_multiple=required_batch_size_multiple,
)
if i < len(self.raw_sub_batch_samplers):
self.raw_sub_batch_samplers[i] = sub_batch_sampler
else:
self.raw_sub_batch_samplers.append(sub_batch_sampler)
def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
batch_samplers = []
for i, _ in enumerate(self.datasets):
if i > 0:
sub_batch_sampler = [
[y + self.cumulative_sizes[i - 1] for y in x]
for x in self.raw_sub_batch_samplers[i]
]
else:
sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
smp_r = mult_ratios[i]
if smp_r != 1:
is_increase = "increased" if smp_r > 1 else "decreased"
logger.info(
"number of batch for the dataset {} is {} from {} to {}".format(
self.id_to_mode[i],
is_increase,
len(sub_batch_sampler),
int(len(sub_batch_sampler) * smp_r),
)
)
mul_samplers = []
for _ in range(math.floor(smp_r)):
mul_samplers = mul_samplers + sub_batch_sampler
if math.floor(smp_r) != smp_r:
with data_utils.numpy_seed(seed + self._cur_epoch):
np.random.shuffle(sub_batch_sampler)
smp_num = int(
(smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
)
mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
sub_batch_sampler = mul_samplers
else:
logger.info(
"dataset {} batch number is {} ".format(
self.id_to_mode[i], len(sub_batch_sampler)
)
)
batch_samplers.append(sub_batch_sampler)
return batch_samplers
class LangPairMaskDataset(FairseqDataset):
def __init__(
self,
dataset: LanguagePairDataset,
src_eos: int,
src_bos: Optional[int] = None,
noise_id: Optional[int] = -1,
mask_ratio: Optional[float] = 0,
mask_type: Optional[str] = "random",
):
self.dataset = dataset
self.src_eos = src_eos
self.src_bos = src_bos
self.noise_id = noise_id
self.mask_ratio = mask_ratio
self.mask_type = mask_type
assert mask_type in ("random", "tail")
@property
def src_sizes(self):
return self.dataset.src_sizes
@property
def tgt_sizes(self):
return self.dataset.tgt_sizes
@property
def sizes(self):
# dataset.sizes can be a dynamically computed sizes:
return self.dataset.sizes
def get_batch_shapes(self):
if hasattr(self.dataset, "get_batch_shapes"):
return self.dataset.get_batch_shapes()
return self.dataset.buckets
def num_tokens_vec(self, indices):
return self.dataset.num_tokens_vec(indices)
def __len__(self):
return len(self.dataset)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
def mask_src_tokens(self, sample):
src_item = sample["source"]
mask = None
if self.mask_type == "random":
mask = torch.rand(len(src_item)).le(self.mask_ratio)
else:
mask = torch.ones(len(src_item))
mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
mask = mask.eq(1)
if src_item[0] == self.src_bos:
mask[0] = False
if src_item[-1] == self.src_eos:
mask[-1] = False
mask_src_item = src_item.masked_fill(mask, self.noise_id)
smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
return smp
def __getitem__(self, index):
sample = self.dataset[index]
if self.mask_ratio > 0:
sample = self.mask_src_tokens(sample)
return sample
def collater(self, samples, pad_to_length=None):
return self.dataset.collater(samples, pad_to_length)
class FileAudioDatasetWrapper(FileAudioDataset):
def collater(self, samples):
samples = super().collater(samples)
if len(samples) == 0:
return {}
samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
samples["net_input"]["prev_output_tokens"] = None
del samples["net_input"]["source"]
samples["net_input"]["src_lengths"] = None
samples["net_input"]["alignment"] = None
return samples

View File

@@ -0,0 +1,393 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import io
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
logger = logging.getLogger(__name__)
class RawAudioDataset(FairseqDataset):
def __init__(
self,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
compute_mask_indices=False,
**mask_compute_kwargs,
):
super().__init__()
self.sample_rate = sample_rate
self.sizes = []
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.min_sample_size = min_sample_size
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
self.compute_mask_indices = compute_mask_indices
if self.compute_mask_indices:
self.mask_compute_kwargs = mask_compute_kwargs
self._features_size_map = {}
self._C = mask_compute_kwargs["encoder_embed_dim"]
self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
def __getitem__(self, index):
raise NotImplementedError()
def __len__(self):
return len(self.sizes)
def postprocess(self, feats, curr_sample_rate):
if feats.dim() == 2:
feats = feats.mean(-1)
if curr_sample_rate != self.sample_rate:
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
assert feats.dim() == 1, feats.dim()
if self.normalize:
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end]
def _compute_mask_indices(self, dims, padding_mask):
B, T, C = dims
mask_indices, mask_channel_indices = None, None
if self.mask_compute_kwargs["mask_prob"] > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_compute_kwargs["mask_prob"],
self.mask_compute_kwargs["mask_length"],
self.mask_compute_kwargs["mask_selection"],
self.mask_compute_kwargs["mask_other"],
min_masks=2,
no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
min_space=self.mask_compute_kwargs["mask_min_space"],
)
mask_indices = torch.from_numpy(mask_indices)
if self.mask_compute_kwargs["mask_channel_prob"] > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_compute_kwargs["mask_channel_prob"],
self.mask_compute_kwargs["mask_channel_length"],
self.mask_compute_kwargs["mask_channel_selection"],
self.mask_compute_kwargs["mask_channel_other"],
no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
min_space=self.mask_compute_kwargs["mask_channel_min_space"],
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
)
return mask_indices, mask_channel_indices
@staticmethod
def _bucket_tensor(tensor, num_pad, value):
return F.pad(tensor, (0, num_pad), value=value)
def collater(self, samples):
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
sources = [s["source"] for s in samples]
sizes = [len(s) for s in sources]
if self.pad:
target_size = min(max(sizes), self.max_sample_size)
else:
target_size = min(min(sizes), self.max_sample_size)
collated_sources = sources[0].new_zeros(len(sources), target_size)
padding_mask = (
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
if diff == 0:
collated_sources[i] = source
elif diff < 0:
assert self.pad
collated_sources[i] = torch.cat(
[source, source.new_full((-diff,), 0.0)]
)
padding_mask[i, diff:] = True
else:
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask
if hasattr(self, "num_buckets") and self.num_buckets > 0:
assert self.pad, "Cannot bucket without padding first."
bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
num_pad = bucket - collated_sources.size(-1)
if num_pad:
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
if self.compute_mask_indices:
B = input["source"].size(0)
T = self._get_mask_indices_dims(input["source"].size(-1))
padding_mask_reshaped = input["padding_mask"].clone()
extra = padding_mask_reshaped.size(1) % T
if extra > 0:
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
padding_mask_reshaped = padding_mask_reshaped.view(
padding_mask_reshaped.size(0), T, -1
)
padding_mask_reshaped = padding_mask_reshaped.all(-1)
input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
mask_indices, mask_channel_indices = self._compute_mask_indices(
(B, T, self._C),
padding_mask_reshaped,
)
input["mask_indices"] = mask_indices
input["mask_channel_indices"] = mask_channel_indices
out["sample_size"] = mask_indices.sum().item()
out["net_input"] = input
return out
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
if size not in self._features_size_map:
L_in = size
for (_, kernel_size, stride) in self._conv_feature_layers:
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
L_out = 1 + L_out // stride
L_in = L_out
self._features_size_map[size] = L_out
return self._features_size_map[size]
def num_tokens(self, index):
return self.size(index)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if self.pad:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
order.append(
np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
)
return np.lexsort(order)[::-1]
else:
return np.arange(len(self))
def set_bucket_info(self, num_buckets):
self.num_buckets = num_buckets
if self.num_buckets > 0:
self._collated_sizes = np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
self.buckets = get_buckets(
self._collated_sizes,
self.num_buckets,
)
self._bucketed_sizes = get_bucketed_sizes(
self._collated_sizes, self.buckets
)
logger.info(
f"{len(self.buckets)} bucket(s) for the audio dataset: "
f"{self.buckets}"
)
class FileAudioDataset(RawAudioDataset):
def __init__(
self,
manifest_path,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
num_buckets=0,
compute_mask_indices=False,
text_compression_level=TextCompressionLevel.none,
**mask_compute_kwargs,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
**mask_compute_kwargs,
)
self.text_compressor = TextCompressor(level=text_compression_level)
skipped = 0
self.fnames = []
sizes = []
self.skipped_indices = set()
with open(manifest_path, "r") as f:
self.root_dir = f.readline().strip()
for i, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_sample_size is not None and sz < min_sample_size:
skipped += 1
self.skipped_indices.add(i)
continue
self.fnames.append(self.text_compressor.compress(items[0]))
sizes.append(sz)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
self.sizes = np.array(sizes, dtype=np.int64)
try:
import pyarrow
self.fnames = pyarrow.array(self.fnames)
except:
logger.debug(
"Could not create a pyarrow array. Please install pyarrow for better performance"
)
pass
self.set_bucket_info(num_buckets)
def __getitem__(self, index):
import soundfile as sf
fn = self.fnames[index]
fn = fn if isinstance(self.fnames, list) else fn.as_py()
fn = self.text_compressor.decompress(fn)
path_or_fp = os.path.join(self.root_dir, fn)
_path, slice_ptr = parse_path(path_or_fp)
if len(slice_ptr) == 2:
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
assert is_sf_audio_data(byte_data)
path_or_fp = io.BytesIO(byte_data)
wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
class BinarizedAudioDataset(RawAudioDataset):
def __init__(
self,
data_dir,
split,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
num_buckets=0,
compute_mask_indices=False,
**mask_compute_kwargs,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
**mask_compute_kwargs,
)
from fairseq.data import data_utils, Dictionary
self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
root_path = os.path.join(data_dir, f"{split}.root")
if os.path.exists(root_path):
with open(root_path, "r") as f:
self.root_dir = next(f).strip()
else:
self.root_dir = None
fnames_path = os.path.join(data_dir, split)
self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
lengths_path = os.path.join(data_dir, f"{split}.lengths")
with open(lengths_path, "r") as f:
for line in f:
sz = int(line.rstrip())
assert (
sz >= min_sample_size
), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
self.sizes.append(sz)
self.sizes = np.array(self.sizes, dtype=np.int64)
self.set_bucket_info(num_buckets)
logger.info(f"loaded {len(self.fnames)} samples")
def __getitem__(self, index):
import soundfile as sf
fname = self.fnames_dict.string(self.fnames[index], separator="")
if self.root_dir:
fname = os.path.join(self.root_dir, fname)
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}

View File

@@ -0,0 +1,379 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
TextTargetMultitaskData,
_collate_frames,
)
logger = logging.getLogger(__name__)
@dataclass
class SpeechToSpeechDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
target_speaker: Optional[torch.Tensor] = None
tgt_lang_tag: Optional[int] = None
class SpeechToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2SDataConfig,
src_audio_paths: List[str],
src_n_frames: List[int],
tgt_audio_paths: List[str],
tgt_n_frames: List[int],
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
):
tgt_texts = tgt_audio_paths if target_is_code else None
super().__init__(
split=split,
is_train_split=is_train_split,
cfg=data_cfg,
audio_paths=src_audio_paths,
n_frames=src_n_frames,
ids=ids,
tgt_dict=tgt_dict,
tgt_texts=tgt_texts,
src_langs=src_langs,
tgt_langs=tgt_langs,
n_frames_per_step=n_frames_per_step,
)
self.tgt_audio_paths = tgt_audio_paths
self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
assert not target_is_code or tgt_dict is not None
self.target_is_code = target_is_code
assert len(tgt_audio_paths) == self.n_samples
assert len(tgt_n_frames) == self.n_samples
self.tgt_speakers = None
if self.cfg.target_speaker_embed:
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
self.cfg.target_speaker_embed, split
)
spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
assert len(self.tgt_speakers) == self.n_samples
logger.info(self.__repr__())
def pack_units(self, input: torch.Tensor) -> torch.Tensor:
if self.n_frames_per_step <= 1:
return input
offset = 4
vocab_size = (
len(self.tgt_dict) - offset
) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
assert input.dim() == 1
stacked_input = (
input[:-1].view(-1, self.n_frames_per_step) - offset
) # remove <eos>
scale = [
pow(vocab_size, self.n_frames_per_step - 1 - i)
for i in range(self.n_frames_per_step)
]
scale = torch.LongTensor(scale).squeeze(0)
res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
res[:-1] = (stacked_input * scale).sum(dim=1) + offset
return res
def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
source = self._get_source_audio(index)
tgt_lang_tag = None
if self.cfg.prepend_tgt_lang_tag_as_bos:
# prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
if not self.target_is_code:
target = get_features_or_waveform(self.tgt_audio_paths[index])
target = torch.from_numpy(target).float()
target = self.pack_frames(target)
else:
target = self.tgt_dict.encode_line(
self.tgt_audio_paths[index],
add_if_not_exist=False,
append_eos=True,
).long()
if self.n_frames_per_step > 1:
n_tgt_frame = target.size(0) - 1 # exclude <eos>
keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
target = torch.cat(
(
target[:keep_n_tgt_frame],
target.new_full((1,), self.tgt_dict.eos()),
),
dim=0,
)
if self.tgt_speakers:
tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
tgt_spk = torch.from_numpy(tgt_spk).float()
else:
tgt_spk = torch.FloatTensor([])
return SpeechToSpeechDatasetItem(
index=index,
source=source,
target=target,
target_speaker=tgt_spk,
tgt_lang_tag=tgt_lang_tag,
)
def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
if self.target_is_code:
target = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
# convert stacked units to a single id
pack_targets = [self.pack_units(x.target) for x in samples]
prev_output_tokens = fairseq_data_utils.collate_tokens(
pack_targets,
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
)
target_lengths = torch.tensor(
[x.size(0) for x in pack_targets], dtype=torch.long
)
else:
target = _collate_frames([x.target for x in samples], is_audio_input=False)
bsz, _, d = target.size()
prev_output_tokens = torch.cat(
(target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
)
target_lengths = torch.tensor(
[x.target.size(0) for x in samples], dtype=torch.long
)
return target, prev_output_tokens, target_lengths
def collater(
self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
# sort samples by descending number of frames
n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, prev_output_tokens, target_lengths = self._collate_target(samples)
target = target.index_select(0, order)
target_lengths = target_lengths.index_select(0, order)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(x.target.size(0) for x in samples)
tgt_speakers = None
if self.cfg.target_speaker_embed:
tgt_speakers = _collate_frames(
[x.target_speaker for x in samples], is_audio_input=True
).index_select(0, order)
net_input = {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
"tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
}
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
for i in range(len(samples)):
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
out = {
"id": indices,
"net_input": net_input,
"speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
if return_order:
out["order"] = order
return out
class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multitask_data = {}
def add_multitask_dataset(self, task_name, task_data):
self.multitask_data[task_name] = task_data
def __getitem__(
self, index: int
) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
s2s_data = super().__getitem__(index)
multitask_target = {}
sample_id = self.ids[index]
tgt_lang = self.tgt_langs[index]
for task_name, task_dataset in self.multitask_data.items():
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
return s2s_data, multitask_target
def collater(
self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
) -> Dict:
if len(samples) == 0:
return {}
out = super().collater([s for s, _ in samples], return_order=True)
order = out["order"]
del out["order"]
for task_name, task_dataset in self.multitask_data.items():
if "multitask" not in out:
out["multitask"] = {}
d = [s[task_name] for _, s in samples]
task_target = task_dataset.collater(d)
out["multitask"][task_name] = {
"target": task_target["target"].index_select(0, order),
"target_lengths": task_target["target_lengths"].index_select(0, order),
"ntokens": task_target["ntokens"],
}
out["multitask"][task_name]["net_input"] = {
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
0, order
),
}
return out
class SpeechToSpeechDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
# optional columns
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_LANG = ""
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
data_cfg: S2SDataConfig,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
multitask: Optional[Dict] = None,
) -> SpeechToSpeechDataset:
audio_root = Path(data_cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
src_audio_paths = [
(audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
]
tgt_audio_paths = [
s[cls.KEY_TGT_AUDIO]
if target_is_code
else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
for s in samples
]
src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
)
ds = dataset_cls(
split=split_name,
is_train_split=is_train_split,
data_cfg=data_cfg,
src_audio_paths=src_audio_paths,
src_n_frames=src_n_frames,
tgt_audio_paths=tgt_audio_paths,
tgt_n_frames=tgt_n_frames,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
target_is_code=target_is_code,
tgt_dict=tgt_dict,
n_frames_per_step=n_frames_per_step,
)
if has_multitask:
for task_name, task_obj in multitask.items():
task_data = TextTargetMultitaskData(
task_obj.args, split_name, task_obj.target_dictionary
)
ds.add_multitask_dataset(task_name, task_data)
return ds
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2SDataConfig,
splits: str,
is_train_split: bool,
epoch: int,
seed: int,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
multitask: Optional[Dict] = None,
) -> SpeechToSpeechDataset:
datasets = []
for split in splits.split(","):
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
ds = cls._from_list(
split_name=split,
is_train_split=is_train_split,
samples=samples,
data_cfg=data_cfg,
target_is_code=target_is_code,
tgt_dict=tgt_dict,
n_frames_per_step=n_frames_per_step,
multitask=multitask,
)
datasets.append(ds)
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,733 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
logger = logging.getLogger(__name__)
def _collate_frames(
frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
if is_audio_input:
out = frames[0].new_zeros((len(frames), max_len))
else:
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
def _is_int_or_np_int(n):
return isinstance(n, int) or (
isinstance(n, np.generic) and isinstance(n.item(), int)
)
@dataclass
class SpeechToTextDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
class SpeechToTextDataset(FairseqDataset):
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
append_eos=True,
):
self.split, self.is_train_split = split, is_train_split
self.cfg = cfg
self.audio_paths, self.n_frames = audio_paths, n_frames
self.n_samples = len(audio_paths)
assert len(n_frames) == self.n_samples > 0
assert src_texts is None or len(src_texts) == self.n_samples
assert tgt_texts is None or len(tgt_texts) == self.n_samples
assert speakers is None or len(speakers) == self.n_samples
assert src_langs is None or len(src_langs) == self.n_samples
assert tgt_langs is None or len(tgt_langs) == self.n_samples
assert ids is None or len(ids) == self.n_samples
assert (tgt_dict is None and tgt_texts is None) or (
tgt_dict is not None and tgt_texts is not None
)
self.src_texts, self.tgt_texts = src_texts, tgt_texts
self.src_langs, self.tgt_langs = src_langs, tgt_langs
self.speakers = speakers
self.tgt_dict = tgt_dict
self.check_tgt_lang_tag()
self.ids = ids
self.shuffle = cfg.shuffle if is_train_split else False
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
self.cfg.get_feature_transforms(split, is_train_split)
)
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
self.cfg.get_waveform_transforms(split, is_train_split)
)
# TODO: add these to data_cfg.py
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
self.cfg.get_dataset_transforms(split, is_train_split)
)
# check proper usage of transforms
if self.feature_transforms and self.cfg.use_audio_input:
logger.warning(
"Feature transforms will not be applied. To use feature transforms, "
"set use_audio_input as False in config."
)
self.pre_tokenizer = pre_tokenizer
self.bpe_tokenizer = bpe_tokenizer
self.n_frames_per_step = n_frames_per_step
self.speaker_to_id = speaker_to_id
self.tgt_lens = self.get_tgt_lens_and_check_oov()
self.append_eos = append_eos
logger.info(self.__repr__())
def get_tgt_lens_and_check_oov(self):
if self.tgt_texts is None:
return [0 for _ in range(self.n_samples)]
tgt_lens = []
n_tokens, n_oov_tokens = 0, 0
for i in range(self.n_samples):
tokenized = self.get_tokenized_tgt_text(i).split(" ")
oov_tokens = [
t
for t in tokenized
if self.tgt_dict.index(t) == self.tgt_dict.unk_index
]
n_tokens += len(tokenized)
n_oov_tokens += len(oov_tokens)
tgt_lens.append(len(tokenized))
logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
return tgt_lens
def __repr__(self):
return (
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
f"n_frames_per_step={self.n_frames_per_step}, "
f"shuffle={self.shuffle}, "
f"feature_transforms={self.feature_transforms}, "
f"waveform_transforms={self.waveform_transforms}, "
f"dataset_transforms={self.dataset_transforms})"
)
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
@classmethod
def tokenize(cls, tokenizer, text: str):
return text if tokenizer is None else tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
if _is_int_or_np_int(index):
text = self.tgt_texts[index]
else:
text = " ".join([self.tgt_texts[i] for i in index])
text = self.tokenize(self.pre_tokenizer, text)
text = self.tokenize(self.bpe_tokenizer, text)
return text
def pack_frames(self, feature: torch.Tensor):
if self.n_frames_per_step == 1:
return feature
n_packed_frames = feature.shape[0] // self.n_frames_per_step
feature = feature[: self.n_frames_per_step * n_packed_frames]
return feature.reshape(n_packed_frames, -1)
@classmethod
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
assert lang_tag_idx != dictionary.unk()
return lang_tag_idx
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
"""
Gives source audio for given index with any relevant transforms
applied. For ConcatAug, source audios for given indices are
concatenated in given order.
Args:
index (int or List[int]): index—or in the case of ConcatAug,
indices—to pull the source audio for
Returns:
source audios concatenated for given indices with
relevant transforms appplied
"""
if _is_int_or_np_int(index):
source = get_features_or_waveform(
self.audio_paths[index],
need_waveform=self.cfg.use_audio_input,
use_sample_rate=self.cfg.use_sample_rate,
waveform_transforms=self.waveform_transforms,
)
else:
source = np.concatenate(
[
get_features_or_waveform(
self.audio_paths[i],
need_waveform=self.cfg.use_audio_input,
use_sample_rate=self.cfg.use_sample_rate,
waveform_transforms=self.waveform_transforms,
)
for i in index
]
)
if self.cfg.use_audio_input:
source = torch.from_numpy(source).float()
if self.cfg.standardize_audio:
with torch.no_grad():
source = F.layer_norm(source, source.shape)
else:
if self.feature_transforms is not None:
source = self.feature_transforms(source)
source = torch.from_numpy(source).float()
return source
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
if has_concat:
concat = self.dataset_transforms.get_transform(ConcatAugment)
indices = concat.find_indices(index, self.n_frames, self.n_samples)
source = self._get_source_audio(indices if has_concat else index)
source = self.pack_frames(source)
target = None
if self.tgt_texts is not None:
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=self.append_eos
).long()
if self.cfg.prepend_tgt_lang_tag:
lang_tag_idx = self.get_lang_tag_idx(
self.tgt_langs[index], self.tgt_dict
)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
if self.cfg.prepend_bos_and_append_tgt_lang_tag:
bos = torch.LongTensor([self.tgt_dict.bos()])
lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
assert lang_tag_idx != self.tgt_dict.unk()
lang_tag_idx = torch.LongTensor([lang_tag_idx])
target = torch.cat((bos, target, lang_tag_idx), 0)
speaker_id = None
if self.speaker_to_id is not None:
speaker_id = self.speaker_to_id[self.speakers[index]]
return SpeechToTextDatasetItem(
index=index, source=source, target=target, speaker_id=speaker_id
)
def __len__(self):
return self.n_samples
def collater(
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
sources = [x.source for x in samples]
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
if has_NOAug and self.cfg.use_audio_input:
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
sources = NOAug(sources)
frames = _collate_frames(sources, self.cfg.use_audio_input)
# sort samples by descending number of frames
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
target = target.index_select(0, order)
target_lengths = torch.tensor(
[x.target.size(0) for x in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(x.target.size(0) for x in samples)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
net_input = {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
}
out = {
"id": indices,
"net_input": net_input,
"speaker": speaker,
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
if return_order:
out["order"] = order
return out
def num_tokens(self, index):
return self.n_frames[index]
def size(self, index):
return self.n_frames[index], self.tgt_lens[index]
@property
def sizes(self):
return np.array(self.n_frames)
@property
def can_reuse_epoch_itr_across_epochs(self):
return True
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
# first by descending order of # of frames then by original/random order
order.append([-n for n in self.n_frames])
return np.lexsort(order)
def prefetch(self, indices):
raise False
class TextTargetMultitaskData(object):
# mandatory columns
KEY_ID, KEY_TEXT = "id", "tgt_text"
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(self, args, split, tgt_dict):
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
self.dict = tgt_dict
self.append_eos = args.decoder_type != "ctc"
self.pre_tokenizer = self.build_tokenizer(args)
self.bpe_tokenizer = self.build_bpe(args)
self.prepend_bos_and_append_tgt_lang_tag = (
args.prepend_bos_and_append_tgt_lang_tag
)
self.eos_token = args.eos_token
self.lang_tag_mapping = args.get_lang_tag_mapping
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
@classmethod
def tokenize(cls, tokenizer, text: str):
return text if tokenizer is None else tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: int):
text = self.tokenize(self.pre_tokenizer, self.data[index])
text = self.tokenize(self.bpe_tokenizer, text)
return text
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
lang_tag_idx = dictionary.index(lang_tag)
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
return lang_tag_idx
def build_tokenizer(self, args):
pre_tokenizer = args.config.get("pre_tokenizer")
if pre_tokenizer is not None:
logger.info(f"pre-tokenizer: {pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**pre_tokenizer))
else:
return None
def build_bpe(self, args):
bpe_tokenizer = args.config.get("bpe_tokenizer")
if bpe_tokenizer is not None:
logger.info(f"tokenizer: {bpe_tokenizer}")
return encoders.build_bpe(Namespace(**bpe_tokenizer))
else:
return None
def get(self, sample_id, tgt_lang=None):
if sample_id in self.data:
tokenized = self.get_tokenized_tgt_text(sample_id)
target = self.dict.encode_line(
tokenized,
add_if_not_exist=False,
append_eos=self.append_eos,
)
if self.prepend_bos_and_append_tgt_lang_tag:
bos = torch.LongTensor([self.dict.bos()])
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
assert lang_tag_idx != self.dict.unk()
lang_tag_idx = torch.LongTensor([lang_tag_idx])
target = torch.cat((bos, target, lang_tag_idx), 0)
return target
else:
logger.warning(f"no target for {sample_id}")
return torch.IntTensor([])
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
).long()
prev_out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=True,
).long()
target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
ntokens = sum(t.size(0) for t in samples)
output = {
"prev_output_tokens": prev_out,
"target": out,
"target_lengths": target_lengths,
"ntokens": ntokens,
}
return output
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multitask_data = {}
def add_multitask_dataset(self, task_name, task_data):
self.multitask_data[task_name] = task_data
def __getitem__(
self, index: int
) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
s2t_data = super().__getitem__(index)
multitask_target = {}
sample_id = self.ids[index]
tgt_lang = self.tgt_langs[index]
for task_name, task_dataset in self.multitask_data.items():
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
return s2t_data, multitask_target
def collater(
self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
) -> Dict:
if len(samples) == 0:
return {}
out = super().collater([s for s, _ in samples], return_order=True)
order = out["order"]
del out["order"]
for task_name, task_dataset in self.multitask_data.items():
if "multitask" not in out:
out["multitask"] = {}
d = [s[task_name] for _, s in samples]
task_target = task_dataset.collater(d)
out["multitask"][task_name] = {
"target": task_target["target"].index_select(0, order),
"target_lengths": task_target["target_lengths"].index_select(0, order),
"ntokens": task_target["ntokens"],
}
out["multitask"][task_name]["net_input"] = {
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
0, order
),
}
return out
class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
)
ds = dataset_cls(
split=split_name,
is_train_split=is_train_split,
cfg=cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
if has_multitask:
for task_name, task_obj in multitask.items():
task_data = TextTargetMultitaskData(
task_obj.args, split_name, task_obj.target_dictionary
)
ds.add_multitask_dataset(task_name, task_data)
return ds
@classmethod
def get_size_ratios(
cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
) -> List[float]:
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
id_to_lp, lp_to_sz = {}, defaultdict(int)
for ds in datasets:
lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
assert len(lang_pairs) == 1
lang_pair = list(lang_pairs)[0]
id_to_lp[ds.split] = lang_pair
lp_to_sz[lang_pair] += sum(ds.n_frames)
sz_sum = sum(v for v in lp_to_sz.values())
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
prob_sum = sum(v for v in lp_to_tgt_prob.values())
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
lp_to_sz_ratio = {
k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
}
size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
p_formatted = {
k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
}
logger.info(f"sampling probability balancing: {p_formatted}")
sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
logger.info(f"balanced sampling size ratio: {sr_formatted}")
return size_ratio
@classmethod
def _load_samples_from_tsv(cls, root: str, split: str):
tsv_path = Path(root) / f"{split}.tsv"
if not tsv_path.is_file():
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
samples = [dict(e) for e in reader]
if len(samples) == 0:
raise ValueError(f"Empty manifest: {tsv_path}")
return samples
@classmethod
def _from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
split: str,
tgt_dict,
is_train_split: bool,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
samples = cls._load_samples_from_tsv(root, split)
return cls._from_list(
split,
is_train_split,
samples,
cfg,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask,
)
@classmethod
def from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
splits: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
n_frames_per_step: int = 1,
speaker_to_id=None,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
datasets = [
cls._from_tsv(
root=root,
cfg=cfg,
split=split,
tgt_dict=tgt_dict,
is_train_split=is_train_split,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
multitask=multitask,
)
for split in splits.split(",")
]
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for r, d in zip(size_ratios, datasets)
]
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional
import torch
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
)
logger = logging.getLogger(__name__)
class S2TJointDataConfig(S2TDataConfig):
"""Wrapper class for data config YAML"""
@property
def src_vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("src_vocab_filename", "src_dict.txt")
@property
def src_pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_pre_tokenizer", {"tokenizer": None})
@property
def src_bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply on source text after pre-tokenization.
Returning a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_bpe_tokenizer", {"bpe": None})
@property
def prepend_tgt_lang_tag_no_change(self) -> bool:
"""Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
to-many multilingual setting). No change needed during inference.
This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
"""
value = self.config.get("prepend_tgt_lang_tag_no_change", None)
if value is None:
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
return value
@property
def sampling_text_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling. (text
input only) (alpha = 1 for no resampling)"""
return self.config.get("sampling_text_alpha", 1.0)
class SpeechToTextJointDatasetItem(NamedTuple):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
src_txt_tokens: Optional[torch.Tensor] = None
tgt_lang_tag: Optional[int] = None
src_lang_tag: Optional[int] = None
tgt_alignment: Optional[torch.Tensor] = None
# use_src_lang_id:
# 0: don't use src_lang_id
# 1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TJointDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
src_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
append_eos: Optional[bool] = True,
alignment: Optional[List[str]] = None,
use_src_lang_id: Optional[int] = 0,
):
super().__init__(
split,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
append_eos=append_eos,
)
self.src_dict = src_dict
self.src_pre_tokenizer = src_pre_tokenizer
self.src_bpe_tokenizer = src_bpe_tokenizer
self.alignment = None
self.use_src_lang_id = use_src_lang_id
if alignment is not None:
self.alignment = [
[float(s) for s in sample.split()] for sample in alignment
]
def get_tokenized_src_text(self, index: int):
text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
text = self.tokenize(self.src_bpe_tokenizer, text)
return text
def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
s2t_dataset_item = super().__getitem__(index)
src_tokens = None
src_lang_tag = None
if self.src_texts is not None and self.src_dict is not None:
src_tokens = self.get_tokenized_src_text(index)
src_tokens = self.src_dict.encode_line(
src_tokens, add_if_not_exist=False, append_eos=True
).long()
if self.use_src_lang_id > 0:
src_lang_tag = self.get_lang_tag_idx(
self.src_langs[index], self.src_dict
)
tgt_lang_tag = None
if self.cfg.prepend_tgt_lang_tag_no_change:
# prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
ali = None
if self.alignment is not None:
ali = torch.Tensor(self.alignment[index]).float()
return SpeechToTextJointDatasetItem(
index=index,
source=s2t_dataset_item.source,
target=s2t_dataset_item.target,
src_txt_tokens=src_tokens,
tgt_lang_tag=tgt_lang_tag,
src_lang_tag=src_lang_tag,
tgt_alignment=ali,
)
def __len__(self):
return self.n_samples
def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
s2t_out = super().collater(samples, return_order=True)
if s2t_out == {}:
return s2t_out
net_input, order = s2t_out["net_input"], s2t_out["order"]
if self.src_texts is not None and self.src_dict is not None:
src_txt_tokens = fairseq_data_utils.collate_tokens(
[x.src_txt_tokens for x in samples],
self.src_dict.pad(),
self.src_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
src_txt_lengths = torch.tensor(
[x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
)
if self.use_src_lang_id > 0:
src_lang_idxs = torch.tensor(
[s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
)
if self.use_src_lang_id == 1: # replace eos with lang_id
eos_idx = src_txt_lengths - 1
src_txt_tokens.scatter_(
1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
)
else:
raise NotImplementedError("Implementation is required")
src_txt_tokens = src_txt_tokens.index_select(0, order)
src_txt_lengths = src_txt_lengths.index_select(0, order)
net_input["src_txt_tokens"] = src_txt_tokens
net_input["src_txt_lengths"] = src_txt_lengths
net_input["alignment"] = None
if self.alignment is not None:
max_len = max([s.tgt_alignment.size(0) for s in samples])
alignment = torch.ones(len(samples), max_len).float()
for i, s in enumerate(samples):
cur_len = s.tgt_alignment.size(0)
alignment[i][:cur_len].copy_(s.tgt_alignment)
net_input["alignment"] = alignment.index_select(0, order)
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
for i in range(len(samples)):
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
out = {
"id": s2t_out["id"],
"net_input": net_input,
"target": s2t_out["target"],
"target_lengths": s2t_out["target_lengths"],
"ntokens": s2t_out["ntokens"],
"nsentences": len(samples),
}
return out
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
KEY_ALIGN = "align"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TJointDataConfig,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos,
use_src_lang_id,
) -> SpeechToTextJointDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_alignment = None
if cls.KEY_ALIGN in samples[0].keys():
tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
return SpeechToTextJointDataset(
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=src_pre_tokenizer,
src_bpe_tokenizer=src_bpe_tokenizer,
append_eos=append_eos,
alignment=tgt_alignment,
use_src_lang_id=use_src_lang_id,
)
@classmethod
def _from_tsv(
cls,
root: str,
cfg: S2TJointDataConfig,
split: str,
tgt_dict,
src_dict,
is_train_split: bool,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos: bool,
use_src_lang_id: int,
) -> SpeechToTextJointDataset:
samples = cls._load_samples_from_tsv(root, split)
return cls._from_list(
split,
is_train_split,
samples,
cfg,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos,
use_src_lang_id,
)
@classmethod
def from_tsv(
cls,
root: str,
cfg: S2TJointDataConfig,
splits: str,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
append_eos: Optional[bool] = True,
use_src_lang_id: Optional[int] = 0,
) -> SpeechToTextJointDataset:
datasets = [
cls._from_tsv(
root,
cfg,
split,
tgt_dict,
src_dict,
is_train_split,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos=append_eos,
use_src_lang_id=use_src_lang_id,
)
for split in splits.split(",")
]
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for r, d in zip(size_ratios, datasets)
]
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,250 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
_collate_frames,
)
@dataclass
class TextToSpeechDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
duration: Optional[torch.Tensor] = None
pitch: Optional[torch.Tensor] = None
energy: Optional[torch.Tensor] = None
class TextToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None,
):
super(TextToSpeechDataset, self).__init__(
split,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.durations = durations
self.pitches = pitches
self.energies = energies
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
s2t_item = super().__getitem__(index)
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToSpeechDatasetItem(
index=index,
source=s2t_item.source,
target=s2t_item.target,
speaker_id=s2t_item.speaker_id,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
bsz, _, d = feat.size()
prev_output_tokens = torch.cat(
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": feat,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
KEY_DURATION = "duration"
KEY_PITCH = "pitch"
KEY_ENERGY = "energy"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask=None,
) -> TextToSpeechDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix() for ee in energies
]
energies = None if any(ee is None for ee in energies) else energies
return TextToSpeechDataset(
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
)

View File

@@ -0,0 +1,48 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioWaveformTransform(AudioTransform):
pass
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
def get_audio_waveform_transform(name):
return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
def register_audio_waveform_transform(name):
return register_audio_transform(
name,
AudioWaveformTransform,
AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "waveform")
class CompositeAudioWaveformTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"waveform",
get_audio_waveform_transform,
CompositeAudioWaveformTransform,
config,
)
def __call__(self, x, sample_rate):
for t in self.transforms:
x, sample_rate = t(x, sample_rate)
return x, sample_rate

View File

@@ -0,0 +1,201 @@
from pathlib import Path
import numpy as np
from math import ceil
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.waveform_transforms import (
AudioWaveformTransform,
register_audio_waveform_transform,
)
SNR_MIN = 5.0
SNR_MAX = 15.0
RATE = 0.25
NOISE_RATE = 1.0
NOISE_LEN_MEAN = 0.2
NOISE_LEN_STD = 0.05
class NoiseAugmentTransform(AudioWaveformTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return cls(
_config.get("samples_path", None),
_config.get("snr_min", SNR_MIN),
_config.get("snr_max", SNR_MAX),
_config.get("rate", RATE),
)
def __init__(
self,
samples_path: str,
snr_min: float = SNR_MIN,
snr_max: float = SNR_MAX,
rate: float = RATE,
):
# Sanity checks
assert (
samples_path
), "need to provide path to audio samples for noise augmentation"
assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
self.n_samples = len(self.paths)
assert self.n_samples > 0, f"no audio files found in {samples_path}"
self.snr_min = snr_min
self.snr_max = snr_max
self.rate = rate
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"n_samples={self.n_samples}",
f"snr={self.snr_min}-{self.snr_max}dB",
f"rate={self.rate}",
]
)
+ ")"
)
def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
from fairseq.data.audio.audio_utils import get_waveform
path = self.paths[np.random.randint(0, self.n_samples)]
sample = get_waveform(
path, always_2d=always_2d, output_sample_rate=use_sample_rate
)[0]
# Check dimensions match, else silently skip adding noise to sample
# NOTE: SHOULD THIS QUIT WITH AN ERROR?
is_2d = len(goal_shape) == 2
if len(goal_shape) != sample.ndim or (
is_2d and goal_shape[0] != sample.shape[0]
):
return np.zeros(goal_shape)
# Cut/repeat sample to size
len_dim = len(goal_shape) - 1
n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
return (
repeated[:, start : start + goal_shape[len_dim]]
if is_2d
else repeated[start : start + goal_shape[len_dim]]
)
def _mix(self, source, noise, snr):
get_power = lambda x: np.mean(x**2)
if get_power(noise):
scl = np.sqrt(
get_power(source) / (np.power(10, snr / 10) * get_power(noise))
)
else:
scl = 0
return 1 * source + scl * noise
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
return self.pick_sample(goal_shape, always_2d, use_sample_rate)
def __call__(self, source, sample_rate):
if np.random.random() > self.rate:
return source, sample_rate
noise = self._get_noise(
source.shape, always_2d=True, use_sample_rate=sample_rate
)
return (
self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
sample_rate,
)
@register_audio_waveform_transform("musicaugment")
class MusicAugmentTransform(NoiseAugmentTransform):
pass
@register_audio_waveform_transform("backgroundnoiseaugment")
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
pass
@register_audio_waveform_transform("babbleaugment")
class BabbleAugmentTransform(NoiseAugmentTransform):
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
for i in range(np.random.randint(3, 8)):
speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
if i == 0:
agg_noise = speech
else: # SNR scaled by i (how many noise signals already in agg_noise)
agg_noise = self._mix(agg_noise, speech, i)
return agg_noise
@register_audio_waveform_transform("sporadicnoiseaugment")
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return cls(
_config.get("samples_path", None),
_config.get("snr_min", SNR_MIN),
_config.get("snr_max", SNR_MAX),
_config.get("rate", RATE),
_config.get("noise_rate", NOISE_RATE),
_config.get("noise_len_mean", NOISE_LEN_MEAN),
_config.get("noise_len_std", NOISE_LEN_STD),
)
def __init__(
self,
samples_path: str,
snr_min: float = SNR_MIN,
snr_max: float = SNR_MAX,
rate: float = RATE,
noise_rate: float = NOISE_RATE, # noises per second
noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
noise_len_std: float = NOISE_LEN_STD,
):
super().__init__(samples_path, snr_min, snr_max, rate)
self.noise_rate = noise_rate
self.noise_len_mean = noise_len_mean
self.noise_len_std = noise_len_std
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
agg_noise = np.zeros(goal_shape)
len_dim = len(goal_shape) - 1
is_2d = len(goal_shape) == 2
n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
start_pointers = [
round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
]
for start_pointer in start_pointers:
noise_shape = list(goal_shape)
len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
end_pointer = start_pointer + noise_shape[len_dim]
if end_pointer >= goal_shape[len_dim]:
continue
noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
if is_2d:
agg_noise[:, start_pointer:end_pointer] = (
agg_noise[:, start_pointer:end_pointer] + noise
)
else:
agg_noise[start_pointer:end_pointer] = (
agg_noise[start_pointer:end_pointer] + noise
)
return agg_noise

View File

@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from . import FairseqDataset
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
"""Backtranslate a list of samples.
Given an input (*samples*) of the form:
[{'id': 1, 'source': 'hallo welt'}]
this will return:
[{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
Args:
samples (List[dict]): samples to backtranslate. Individual samples are
expected to have a 'source' key, which will become the 'target'
after backtranslation.
collate_fn (callable): function to collate samples into a mini-batch
generate_fn (callable): function to generate backtranslations
cuda (bool): use GPU for generation (default: ``True``)
Returns:
List[dict]: an updated list of samples with a backtranslated source
"""
collated_samples = collate_fn(samples)
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
generated_sources = generate_fn(s)
id_to_src = {sample["id"]: sample["source"] for sample in samples}
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
return [
{
"id": id.item(),
"target": id_to_src[id.item()],
"source": hypos[0]["tokens"].cpu(),
}
for id, hypos in zip(collated_samples["id"], generated_sources)
]
class BacktranslationDataset(FairseqDataset):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
and returns the corresponding `{generated src, input tgt}` batch.
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
sentences.
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
sentences to be backtranslated.
backtranslation_fn (callable, optional): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
Pass in None when it is not available at initialization time, and
use set_backtranslation_fn function to set it when available.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda: use GPU for generation
"""
def __init__(
self,
tgt_dataset,
src_dict,
tgt_dict=None,
backtranslation_fn=None,
output_collater=None,
cuda=True,
**kwargs
):
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.output_collater = (
output_collater if output_collater is not None else tgt_dataset.collater
)
self.cuda = cuda if torch.cuda.is_available() else False
self.src_dict = src_dict
self.tgt_dict = tgt_dict
def __getitem__(self, index):
"""
Returns a single sample from *tgt_dataset*. Note that backtranslation is
not applied in this step; use :func:`collater` instead to backtranslate
a batch of samples.
"""
return self.tgt_dataset[index]
def __len__(self):
return len(self.tgt_dataset)
def set_backtranslation_fn(self, backtranslation_fn):
self.backtranslation_fn = backtranslation_fn
def collater(self, samples):
"""Merge and backtranslate a list of samples to form a mini-batch.
Using the samples from *tgt_dataset*, load a collated target sample to
feed to the backtranslation model. Then take the backtranslation with
the best score as the source and the original input as the target.
Note: we expect *tgt_dataset* to provide a function `collater()` that
will collate samples into the format expected by *backtranslation_fn*.
After backtranslation, we will feed the new list of samples (i.e., the
`(backtranslated source, original source)` pairs) to *output_collater*
and return the result.
Args:
samples (List[dict]): samples to backtranslate and collate
Returns:
dict: a mini-batch with keys coming from *output_collater*
"""
if samples[0].get("is_dummy", False):
return samples
samples = backtranslate_samples(
samples=samples,
collate_fn=self.tgt_dataset.collater,
generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
cuda=self.cuda,
)
return self.output_collater(samples)
def num_tokens(self, index):
"""Just use the tgt dataset num_tokens"""
return self.tgt_dataset.num_tokens(index)
def ordered_indices(self):
"""Just use the tgt dataset ordered_indices"""
return self.tgt_dataset.ordered_indices()
def size(self, index):
"""Return an example's size as a float or tuple. This value is used
when filtering a dataset with ``--max-positions``.
Note: we use *tgt_dataset* to approximate the length of the source
sentence, since we do not know the actual length until after
backtranslation.
"""
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)
@property
def supports_prefetch(self):
return getattr(self.tgt_dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class BaseWrapperDataset(FairseqDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@property
def sizes(self):
return self.dataset.sizes
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
def prefetch(self, indices):
self.dataset.prefetch(indices)
def get_batch_shapes(self):
return self.dataset.get_batch_shapes()
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
return self.dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
def filter_indices_by_size(self, indices, max_sizes):
return self.dataset.filter_indices_by_size(indices, max_sizes)
@property
def can_reuse_epoch_itr_across_epochs(self):
return self.dataset.can_reuse_epoch_itr_across_epochs
def set_epoch(self, epoch):
super().set_epoch(epoch)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
class BucketPadLengthDataset(BaseWrapperDataset):
"""
Bucket and pad item lengths to the nearest bucket size. This can be used to
reduce the number of unique batch shapes, which is important on TPUs since
each new batch shape requires a recompilation.
Args:
dataset (FairseqDatset): dataset to bucket
sizes (List[int]): all item sizes
num_buckets (int): number of buckets to create
pad_idx (int): padding symbol
left_pad (bool): if True, pad on the left; otherwise right pad
"""
def __init__(
self,
dataset,
sizes,
num_buckets,
pad_idx,
left_pad,
tensor_key=None,
):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
assert num_buckets > 0
self.buckets = get_buckets(sizes, num_buckets)
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
self._tensor_key = tensor_key
def _set_tensor(self, item, val):
if self._tensor_key is None:
return val
item[self._tensor_key] = val
return item
def _get_tensor(self, item):
if self._tensor_key is None:
return item
return item[self._tensor_key]
def _pad(self, tensor, bucket_size, dim=-1):
num_pad = bucket_size - tensor.size(dim)
return F.pad(
tensor,
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
value=self.pad_idx,
)
def __getitem__(self, index):
item = self.dataset[index]
bucket_size = self._bucketed_sizes[index]
tensor = self._get_tensor(item)
padded = self._pad(tensor, bucket_size)
return self._set_tensor(item, padded)
@property
def sizes(self):
return self._bucketed_sizes
def num_tokens(self, index):
return self._bucketed_sizes[index]
def size(self, index):
return self._bucketed_sizes[index]

View File

@@ -0,0 +1,576 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.utils.data
from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
F0_FRAME_SPACE = 0.005 # sec
logger = logging.getLogger(__name__)
class ExpressiveCodeDataConfig(object):
def __init__(self, json_path):
with open(json_path, "r") as f:
self.config = json.load(f)
self._manifests = self.config["manifests"]
@property
def manifests(self):
return self._manifests
@property
def n_units(self):
return self.config["n_units"]
@property
def sampling_rate(self):
return self.config["sampling_rate"]
@property
def code_hop_size(self):
return self.config["code_hop_size"]
@property
def f0_stats(self):
"""pre-computed f0 statistics path"""
return self.config.get("f0_stats", None)
@property
def f0_vq_type(self):
"""naive or precomp"""
return self.config["f0_vq_type"]
@property
def f0_vq_name(self):
return self.config["f0_vq_name"]
def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
key = "log" if log else "linear"
if norm_mean and norm_std:
key += "_mean_std_norm"
elif norm_mean:
key += "_mean_norm"
else:
key += "_none_norm"
return self.config["f0_vq_naive_quantizer"][key]
@property
def f0_vq_n_units(self):
return self.config["f0_vq_n_units"]
@property
def multispkr(self):
"""how to parse speaker label from audio path"""
return self.config.get("multispkr", None)
def get_f0(audio, rate=16000):
try:
import amfm_decompy.basic_tools as basic
import amfm_decompy.pYAAPT as pYAAPT
from librosa.util import normalize
except ImportError:
raise "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
assert audio.ndim == 1
frame_length = 20.0 # ms
to_pad = int(frame_length / 1000 * rate) // 2
audio = normalize(audio) * 0.95
audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
audio = basic.SignalObj(audio, rate)
pitch = pYAAPT.yaapt(
audio,
frame_length=frame_length,
frame_space=F0_FRAME_SPACE * 1000,
nccf_thresh1=0.25,
tda_frame_length=25.0,
)
f0 = pitch.samp_values
return f0
def interpolate_f0(f0):
try:
from scipy.interpolate import interp1d
except ImportError:
raise "Please install scipy (`pip install scipy`)"
orig_t = np.arange(f0.shape[0])
f0_interp = f0[:]
ii = f0_interp != 0
if ii.sum() > 1:
f0_interp = interp1d(
orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
)(orig_t)
f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
return f0_interp
def naive_quantize(x, edges):
bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
return bin_idx
def load_wav(full_path):
try:
import soundfile as sf
except ImportError:
raise "Please install soundfile (`pip install SoundFile`)"
data, sampling_rate = sf.read(full_path)
return data, sampling_rate
def parse_code(code_str, dictionary, append_eos):
code, duration = torch.unique_consecutive(
torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
)
code = " ".join(map(str, code.tolist()))
code = dictionary.encode_line(code, append_eos).short()
if append_eos:
duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos
duration = duration.short()
return code, duration
def parse_manifest(manifest, dictionary):
audio_files = []
codes = []
durations = []
speakers = []
with open(manifest) as info:
for line in info.readlines():
sample = eval(line.strip())
if "cpc_km100" in sample:
k = "cpc_km100"
elif "hubert_km100" in sample:
k = "hubert_km100"
elif "phone" in sample:
k = "phone"
else:
assert False, "unknown format"
code = sample[k]
code, duration = parse_code(code, dictionary, append_eos=True)
codes.append(code)
durations.append(duration)
audio_files.append(sample["audio"])
speakers.append(sample.get("speaker", None))
return audio_files, codes, durations, speakers
def parse_speaker(path, method):
if type(path) == str:
path = Path(path)
if method == "parent_name":
return path.parent.name
elif method == "parent_parent_name":
return path.parent.parent.name
elif method == "_":
return path.name.split("_")[0]
elif method == "single":
return "A"
elif callable(method):
return method(path)
else:
raise NotImplementedError()
def get_f0_by_filename(filename, tgt_sampling_rate):
audio, sampling_rate = load_wav(filename)
if sampling_rate != tgt_sampling_rate:
raise ValueError(
"{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
)
# compute un-interpolated f0, and use Ann's interp in __getitem__ if set
f0 = get_f0(audio, rate=tgt_sampling_rate)
f0 = torch.from_numpy(f0.astype(np.float32))
return f0
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
code_len = durations.sum()
targ_len = int(f0_code_ratio * code_len)
diff = f0.size(0) - targ_len
assert abs(diff) <= tol, (
f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
f" > {tol} (dur=\n{durations})"
)
if diff > 0:
f0 = f0[:targ_len]
elif diff < 0:
f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)
f0_offset = 0.0
seg_f0s = []
for dur in durations:
f0_dur = dur.item() * f0_code_ratio
seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
seg_f0 = seg_f0[seg_f0 != 0]
if len(seg_f0) == 0:
seg_f0 = torch.tensor(0).type(seg_f0.type())
else:
seg_f0 = seg_f0.mean()
seg_f0s.append(seg_f0)
f0_offset += f0_dur
assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
return torch.tensor(seg_f0s)
class Paddings(object):
def __init__(self, code_val, dur_val=0, f0_val=-2.0):
self.code = code_val
self.dur = dur_val
self.f0 = f0_val
class Shifts(object):
def __init__(self, shifts_str, pads):
self._shifts = list(map(int, shifts_str.split(",")))
assert len(self._shifts) == 2, self._shifts
assert all(s >= 0 for s in self._shifts)
self.extra_length = max(s for s in self._shifts)
self.pads = pads
@property
def dur(self):
return self._shifts[0]
@property
def f0(self):
return self._shifts[1]
@staticmethod
def shift_one(seq, left_pad_num, right_pad_num, pad):
assert seq.ndim == 1
bos = seq.new_full((left_pad_num,), pad)
eos = seq.new_full((right_pad_num,), pad)
seq = torch.cat([bos, seq, eos])
mask = torch.ones_like(seq).bool()
mask[left_pad_num : len(seq) - right_pad_num] = 0
return seq, mask
def __call__(self, code, dur, f0):
if self.extra_length == 0:
code_mask = torch.zeros_like(code).bool()
dur_mask = torch.zeros_like(dur).bool()
f0_mask = torch.zeros_like(f0).bool()
return code, code_mask, dur, dur_mask, f0, f0_mask
code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
dur, dur_mask = self.shift_one(
dur, self.dur, self.extra_length - self.dur, self.pads.dur
)
f0, f0_mask = self.shift_one(
f0, self.f0, self.extra_length - self.f0, self.pads.f0
)
return code, code_mask, dur, dur_mask, f0, f0_mask
class CodeDataset(FairseqDataset):
def __init__(
self,
manifest,
dictionary,
dur_dictionary,
f0_dictionary,
config,
discrete_dur,
discrete_f0,
log_f0,
normalize_f0_mean,
normalize_f0_std,
interpolate_f0,
return_filename=False,
strip_filename=True,
shifts="0,0",
return_continuous_f0=False,
):
random.seed(1234)
self.dictionary = dictionary
self.dur_dictionary = dur_dictionary
self.f0_dictionary = f0_dictionary
self.config = config
# duration config
self.discrete_dur = discrete_dur
# pitch config
self.discrete_f0 = discrete_f0
self.log_f0 = log_f0
self.normalize_f0_mean = normalize_f0_mean
self.normalize_f0_std = normalize_f0_std
self.interpolate_f0 = interpolate_f0
self.return_filename = return_filename
self.strip_filename = strip_filename
self.f0_code_ratio = config.code_hop_size / (
config.sampling_rate * F0_FRAME_SPACE
)
# use lazy loading to avoid sharing file handlers across workers
self.manifest = manifest
self._codes = None
self._durs = None
self._f0s = None
with open(f"{manifest}.leng.txt", "r") as f:
lengs = [int(line.rstrip()) for line in f]
edges = np.cumsum([0] + lengs)
self.starts, self.ends = edges[:-1], edges[1:]
with open(f"{manifest}.path.txt", "r") as f:
self.file_names = [line.rstrip() for line in f]
logger.info(f"num entries: {len(self.starts)}")
if os.path.exists(f"{manifest}.f0_stat.pt"):
self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
elif config.f0_stats:
self.f0_stats = torch.load(config.f0_stats)
self.multispkr = config.multispkr
if config.multispkr:
with open(f"{manifest}.speaker.txt", "r") as f:
self.spkrs = [line.rstrip() for line in f]
self.id_to_spkr = sorted(self.spkrs)
self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
self.pads = Paddings(
dictionary.pad(),
0, # use 0 for duration padding
f0_dictionary.pad() if discrete_f0 else -5.0,
)
self.shifts = Shifts(shifts, pads=self.pads)
self.return_continuous_f0 = return_continuous_f0
def get_data_handlers(self):
logging.info(f"loading data for {self.manifest}")
self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
if self.discrete_f0:
if self.config.f0_vq_type == "precomp":
self._f0s = np.load(
f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
)
elif self.config.f0_vq_type == "naive":
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
quantizers_path = self.config.get_f0_vq_naive_quantizer(
self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
)
quantizers = torch.load(quantizers_path)
n_units = self.config.f0_vq_n_units
self._f0_quantizer = torch.from_numpy(quantizers[n_units])
else:
raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
else:
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
def preprocess_f0(self, f0, stats):
"""
1. interpolate
2. log transform (keep unvoiced frame 0)
"""
# TODO: change this to be dependent on config for naive quantizer
f0 = f0.clone()
if self.interpolate_f0:
f0 = interpolate_f0(f0)
mask = f0 != 0 # only process voiced frames
if self.log_f0:
f0[mask] = f0[mask].log()
if self.normalize_f0_mean:
mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
f0[mask] = f0[mask] - mean
if self.normalize_f0_std:
std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
f0[mask] = f0[mask] / std
return f0
def _get_raw_item(self, index):
start, end = self.starts[index], self.ends[index]
if self._codes is None:
self.get_data_handlers()
code = torch.from_numpy(np.array(self._codes[start:end])).long()
dur = torch.from_numpy(np.array(self._durs[start:end]))
f0 = torch.from_numpy(np.array(self._f0s[start:end]))
return code, dur, f0
def __getitem__(self, index):
code, dur, f0 = self._get_raw_item(index)
code = torch.cat([code.new([self.dictionary.bos()]), code])
# use 0 for eos and bos
dur = torch.cat([dur.new([0]), dur])
if self.discrete_dur:
dur = self.dur_dictionary.encode_line(
" ".join(map(str, dur.tolist())), append_eos=False
).long()
else:
dur = dur.float()
# TODO: find a more elegant approach
raw_f0 = None
if self.discrete_f0:
if self.config.f0_vq_type == "precomp":
f0 = self.f0_dictionary.encode_line(
" ".join(map(str, f0.tolist())), append_eos=False
).long()
else:
f0 = f0.float()
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
if self.return_continuous_f0:
raw_f0 = f0
raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
f0 = naive_quantize(f0, self._f0_quantizer)
f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
else:
f0 = f0.float()
if self.multispkr:
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
else:
f0 = self.preprocess_f0(f0, self.f0_stats)
f0 = torch.cat([f0.new([0]), f0])
if raw_f0 is not None:
*_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
else:
raw_f0_mask = None
code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
if raw_f0_mask is not None:
assert (raw_f0_mask == f0_mask).all()
# is a padded frame if either input or output is padded
feats = {
"source": code[:-1],
"target": code[1:],
"mask": code_mask[1:].logical_or(code_mask[:-1]),
"dur_source": dur[:-1],
"dur_target": dur[1:],
"dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
"f0_source": f0[:-1],
"f0_target": f0[1:],
"f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
}
if raw_f0 is not None:
feats["raw_f0"] = raw_f0[1:]
if self.return_filename:
fname = self.file_names[index]
feats["filename"] = (
fname if not self.strip_filename else Path(fname).with_suffix("").name
)
return feats
def __len__(self):
return len(self.starts)
def size(self, index):
return self.ends[index] - self.starts[index] + self.shifts.extra_length
def num_tokens(self, index):
return self.size(index)
def collater(self, samples):
pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
if len(samples) == 0:
return {}
src_tokens = data_utils.collate_tokens(
[s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
)
tgt_tokens = data_utils.collate_tokens(
[s["target"] for s in samples],
pad_idx=pad_idx,
eos_idx=pad_idx, # appending padding, eos is there already
left_pad=False,
)
src_durs, tgt_durs = [
data_utils.collate_tokens(
[s[k] for s in samples],
pad_idx=self.pads.dur,
eos_idx=self.pads.dur,
left_pad=False,
)
for k in ["dur_source", "dur_target"]
]
src_f0s, tgt_f0s = [
data_utils.collate_tokens(
[s[k] for s in samples],
pad_idx=self.pads.f0,
eos_idx=self.pads.f0,
left_pad=False,
)
for k in ["f0_source", "f0_target"]
]
mask, dur_mask, f0_mask = [
data_utils.collate_tokens(
[s[k] for s in samples],
pad_idx=1,
eos_idx=1,
left_pad=False,
)
for k in ["mask", "dur_mask", "f0_mask"]
]
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
n_tokens = sum(len(s["source"]) for s in samples)
result = {
"nsentences": len(samples),
"ntokens": n_tokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"dur_src": src_durs,
"f0_src": src_f0s,
},
"target": tgt_tokens,
"dur_target": tgt_durs,
"f0_target": tgt_f0s,
"mask": mask,
"dur_mask": dur_mask,
"f0_mask": f0_mask,
}
if "filename" in samples[0]:
result["filename"] = [s["filename"] for s in samples]
# TODO: remove this hack into the inference dataset
if "prefix" in samples[0]:
result["prefix"] = [s["prefix"] for s in samples]
if "raw_f0" in samples[0]:
raw_f0s = data_utils.collate_tokens(
[s["raw_f0"] for s in samples],
pad_idx=self.pads.f0,
eos_idx=self.pads.f0,
left_pad=False,
)
result["raw_f0"] = raw_f0s
return result

View File

@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset):
"""Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
def __init__(self, dataset, color_getter):
super().__init__(dataset)
self.color_getter = color_getter
def collater(self, samples):
base_collate = super().collater(samples)
if len(base_collate) > 0:
base_collate["net_input"]["colors"] = torch.tensor(
list(self.color_getter(self.dataset, s["id"]) for s in samples),
dtype=torch.long,
)
return base_collate

View File

@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import numpy as np
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class ConcatDataset(FairseqDataset):
@staticmethod
def cumsum(sequence, sample_ratios):
r, s = [], 0
for e, ratio in zip(sequence, sample_ratios):
curr_len = int(ratio * len(e))
r.append(curr_len + s)
s += curr_len
return r
def __init__(self, datasets, sample_ratios=1):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, "datasets should not be an empty iterable"
self.datasets = list(datasets)
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx][sample_idx]
def _get_dataset_and_sample_index(self, idx: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return dataset_idx, sample_idx
def collater(self, samples, **extra_args):
# For now only supports datasets with same underlying collater implementations
if hasattr(self.datasets[0], "collater"):
return self.datasets[0].collater(samples, **extra_args)
else:
return default_collate(samples, **extra_args)
def size(self, idx: int):
"""
Return an example's size as a float or tuple.
"""
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx].size(sample_idx)
def num_tokens(self, index: int):
return np.max(self.size(index))
def attr(self, attr: str, index: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
return getattr(self.datasets[dataset_idx], attr, None)
@property
def sizes(self):
_dataset_sizes = []
for ds, sr in zip(self.datasets, self.sample_ratios):
if isinstance(ds.sizes, np.ndarray):
_dataset_sizes.append(np.tile(ds.sizes, sr))
else:
# Only support underlying dataset with single size array.
assert isinstance(ds.sizes, list)
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
return np.concatenate(_dataset_sizes)
@property
def supports_prefetch(self):
return all(d.supports_prefetch for d in self.datasets)
def ordered_indices(self):
"""
Returns indices sorted by length. So less padding is needed.
"""
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
# special handling for concatenating lang_pair_datasets
indices = np.arange(len(self))
sizes = self.sizes
tgt_sizes = (
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
)
src_sizes = (
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
)
# sort by target length, then source length
if tgt_sizes is not None:
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
else:
return np.argsort(self.sizes)
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
if getattr(ds, "supports_prefetch", False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)

View File

@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
def __init__(self, *datasets):
super().__init__()
self.datasets = datasets
assert all(
len(ds) == len(datasets[0]) for ds in datasets
), "datasets must have the same length"
def __getitem__(self, index):
return torch.cat([ds[index] for ds in self.datasets])
def __len__(self):
return len(self.datasets[0])
def collater(self, samples):
return self.datasets[0].collater(samples)
@property
def sizes(self):
return sum(ds.sizes for ds in self.datasets)
def num_tokens(self, index):
return sum(ds.num_tokens(index) for ds in self.datasets)
def size(self, index):
return sum(ds.size(index) for ds in self.datasets)
def ordered_indices(self):
return self.datasets[0].ordered_indices()
@property
def supports_prefetch(self):
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
def prefetch(self, indices):
for ds in self.datasets:
if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)

View File

@@ -0,0 +1,604 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
import contextlib
import itertools
import logging
import re
import warnings
from typing import Optional, Tuple
import numpy as np
import torch
from fairseq.file_io import PathManager
from fairseq import utils
import os
logger = logging.getLogger(__name__)
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
for filename in PathManager.ls(path):
parts = filename.split(".")
if len(parts) >= 3 and len(parts[1].split("-")) == 2:
return parts[1].split("-")
return src, dst
def collate_tokens(
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
pad_to_bsz=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
res = values[0].new(batch_size, size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
def load_indexed_dataset(
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
"""A helper function for loading indexed datasets.
Args:
path (str): path to indexed dataset (e.g., 'data-bin/train')
dictionary (~fairseq.data.Dictionary): data dictionary
dataset_impl (str, optional): which dataset implementation to use. If
not provided, it will be inferred automatically. For legacy indexed
data we use the 'cached' implementation by default.
combine (bool, optional): automatically load and combine multiple
datasets. For example, if *path* is 'data-bin/train', then we will
combine 'data-bin/train', 'data-bin/train1', ... and return a
single ConcatDataset instance.
"""
import fairseq.data.indexed_dataset as indexed_dataset
from fairseq.data.concat_dataset import ConcatDataset
datasets = []
for k in itertools.count():
path_k = path + (str(k) if k > 0 else "")
try:
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
except Exception as e:
if "StorageException: [404] Path not found" in str(e):
logger.warning(f"path_k: {e} not found")
else:
raise e
dataset_impl_k = dataset_impl
if dataset_impl_k is None:
dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
dataset = indexed_dataset.make_dataset(
path_k,
impl=dataset_impl_k or default,
fix_lua_indexing=True,
dictionary=dictionary,
)
if dataset is None:
break
logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
datasets.append(dataset)
if not combine:
break
if len(datasets) == 0:
return None
elif len(datasets) == 1:
return datasets[0]
else:
return ConcatDataset(datasets)
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if seed is None:
yield
return
if len(addl_seeds) > 0:
seed = int(hash((seed, *addl_seeds)) % 1e6)
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
def collect_filtered(function, iterable, filtered):
"""
Similar to :func:`filter` but collects filtered elements in ``filtered``.
Args:
function (callable): function that returns ``False`` for elements that
should be filtered
iterable (iterable): iterable to filter
filtered (list): list to store filtered elements
"""
for el in iterable:
if function(el):
yield el
else:
filtered.append(el)
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
def compare_leq(a, b):
return a <= b if not isinstance(a, tuple) else max(a) <= b
def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) <= max_positions
elif isinstance(max_positions, dict):
idx_size = size_fn(idx)
assert isinstance(idx_size, dict)
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
return all(
all(
a is None or b is None or a <= b
for a, b in zip(idx_size[key], max_positions[key])
)
for key in intersect_keys
)
else:
# For MultiCorpusSampledDataset, will generalize it later
if not isinstance(size_fn(idx), Iterable):
return all(size_fn(idx) <= b for b in max_positions)
return all(
a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions)
)
ignored = []
itr = collect_filtered(check_size, indices, ignored)
indices = np.fromiter(itr, dtype=np.int64, count=-1)
return indices, ignored
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
"""
[deprecated] Filter indices based on their size.
Use `FairseqDataset::filter_indices_by_size` instead.
Args:
indices (List[int]): ordered list of dataset indices
dataset (FairseqDataset): fairseq dataset instance
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception if
any elements are filtered (default: False).
"""
warnings.warn(
"data_utils.filter_by_size is deprecated. "
"Use `FairseqDataset::filter_indices_by_size` instead.",
stacklevel=2,
)
if isinstance(max_positions, float) or isinstance(max_positions, int):
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
indices = indices[dataset.sizes[indices] <= max_positions]
elif (
hasattr(dataset, "sizes")
and isinstance(dataset.sizes, list)
and len(dataset.sizes) == 1
):
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
indices = indices[dataset.sizes[0][indices] <= max_positions]
else:
indices, ignored = _filter_by_size_dynamic(
indices, dataset.size, max_positions
)
else:
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
if len(ignored) > 0 and raise_exception:
raise Exception(
(
"Size of sample #{} is invalid (={}) since max_positions={}, "
"skip this example with --skip-invalid-size-inputs-valid-test"
).format(ignored[0], dataset.size(ignored[0]), max_positions)
)
if len(ignored) > 0:
logger.warning(
(
"{} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).format(len(ignored), max_positions, ignored[:10])
)
return indices
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if max_sizes is None:
return indices, []
if type(max_sizes) in (int, float):
max_src_size, max_tgt_size = max_sizes, max_sizes
else:
max_src_size, max_tgt_size = max_sizes
if tgt_sizes is None:
ignored = indices[src_sizes[indices] > max_src_size]
else:
ignored = indices[
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
]
if len(ignored) > 0:
if tgt_sizes is None:
indices = indices[src_sizes[indices] <= max_src_size]
else:
indices = indices[
(src_sizes[indices] <= max_src_size)
& (tgt_sizes[indices] <= max_tgt_size)
]
return indices, ignored.tolist()
def batch_by_size(
indices,
num_tokens_fn,
num_tokens_vec=None,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
fixed_shapes=None,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
num_tokens_vec (List[int], optional): precomputed vector of the number
of tokens for each index in indices (to enable faster batch generation)
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be less than N or a multiple of N (default: 1).
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
only be created with the given shapes. *max_sentences* and
*required_batch_size_multiple* will be ignored (default: None).
"""
try:
from fairseq.data.data_utils_fast import (
batch_by_size_fn,
batch_by_size_vec,
batch_fixed_shapes_fast,
)
except ImportError:
raise ImportError(
"Please build Cython components with: "
"`python setup.py build_ext --inplace`"
)
except ValueError:
raise ValueError(
"Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
)
# added int() to avoid TypeError: an integer is required
max_tokens = int(max_tokens) if max_tokens is not None else -1
max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple
if not isinstance(indices, np.ndarray):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
if fixed_shapes is None:
if num_tokens_vec is None:
return batch_by_size_fn(
indices,
num_tokens_fn,
max_tokens,
max_sentences,
bsz_mult,
)
else:
return batch_by_size_vec(
indices,
num_tokens_vec,
max_tokens,
max_sentences,
bsz_mult,
)
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
sort_order = np.lexsort(
[
fixed_shapes[:, 1].argsort(), # length
fixed_shapes[:, 0].argsort(), # bsz
]
)
fixed_shapes_sorted = fixed_shapes[sort_order]
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def post_process(sentence: str, symbol: str):
if symbol == "sentencepiece":
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
elif symbol == "wordpiece":
sentence = sentence.replace(" ", "").replace("_", " ").strip()
elif symbol == "letter":
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "silence":
import re
sentence = sentence.replace("<SIL>", "")
sentence = re.sub(" +", " ", sentence).strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol in {"subword_nmt", "@@ ", "@@"}:
if symbol == "subword_nmt":
symbol = "@@ "
sentence = (sentence + " ").replace(symbol, "").rstrip()
elif symbol == "none":
pass
elif symbol is not None:
raise NotImplementedError(f"Unknown post_process option: {symbol}")
return sentence
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from possion distribution with lambda = mask length
min_masks: minimum number of masked spans
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
mask_dropout: randomly dropout this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int,
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len and require_same_masks:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
if mask_dropout > 0:
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
mask_idc = np.random.choice(
mask_idc, len(mask_idc) - num_holes, replace=False
)
mask[i, mask_idc] = True
return mask
def get_mem_usage():
try:
import psutil
mb = 1024 * 1024
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
except ImportError:
return "N/A"
# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_padding_mask(lens):
bsz, max_lens = lens.size(0), torch.max(lens).item()
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
return mask
# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_mask(lens):
return ~lengths_to_padding_mask(lens)
def get_buckets(sizes, num_buckets):
buckets = np.unique(
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
interpolation="lower",
)[1:]
)
return buckets
def get_bucketed_sizes(orig_sizes, buckets):
sizes = np.copy(orig_sizes)
assert np.min(sizes) >= 0
start_val = -1
for end_val in buckets:
mask = (sizes > start_val) & (sizes <= end_val)
sizes[mask] = end_val
start_val = end_val
return sizes
def _find_extra_valid_paths(dataset_path: str) -> set:
paths = utils.split_paths(dataset_path)
all_valid_paths = set()
for sub_dir in paths:
contents = PathManager.ls(sub_dir)
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
# Remove .bin, .idx etc
roots = {os.path.splitext(p)[0] for p in all_valid_paths}
return roots
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
"""Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
if (
train_cfg.dataset.ignore_unused_valid_subsets
or train_cfg.dataset.combine_valid_subsets
or train_cfg.dataset.disable_validation
or not hasattr(train_cfg.task, "data")
):
return
other_paths = _find_extra_valid_paths(train_cfg.task.data)
specified_subsets = train_cfg.dataset.valid_subset.split(",")
ignored_paths = [p for p in other_paths if p not in specified_subsets]
if ignored_paths:
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
raise ValueError(msg)

View File

@@ -0,0 +1,178 @@
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t
ctypedef int64_t DTYPE_t
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
np.ndarray[int64_t, ndim=1] indices,
np.ndarray[int64_t, ndim=1] num_tokens_vec,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
if indices.shape[0] == 0:
return []
assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
f"Sentences lengths should not exceed max_tokens={max_tokens}"
)
cdef int32_t indices_len = indices.shape[0]
cdef np.ndarray[int32_t, ndim=1] batches_ends = \
np.zeros(indices_len, dtype=np.int32)
cdef int32_t[:] batches_ends_view = batches_ends
cdef int64_t[:] num_tokens_view = num_tokens_vec
cdef int32_t pos = 0
cdef int32_t new_batch_end = 0
cdef int64_t new_batch_max_tokens = 0
cdef int32_t new_batch_sentences = 0
cdef int64_t new_batch_num_tokens = 0
cdef bool_t overflow = False
cdef bool_t size_matches_with_bsz_mult = False
cdef int32_t batches_count = 0
cdef int32_t batch_start = 0
cdef int64_t tail_max_tokens = 0
cdef int64_t batch_max_tokens = 0
for pos in range(indices_len):
# At every pos we keep stats about the last complete batch [batch_start:batch_end),
# and tail [batch_end:pos].
# 1) Every time when (batch + tail) forms a valid batch
# (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
# 2) When (batch+tail) violates max_tokens or max_sentences constraints
# we finalize running batch, and tail becomes a new batch.
# 3) There is a corner case when tail also violates constraints.
# In that situation [batch_end:pos-1] (tail without the current pos)
# gets added to the finalized batches, while [pos:pos] becomes a new tail.
#
# Important: For the sake of performance try to avoid using function calls within this loop.
tail_max_tokens = tail_max_tokens \
if tail_max_tokens > num_tokens_view[pos] \
else num_tokens_view[pos]
new_batch_end = pos + 1
new_batch_max_tokens = batch_max_tokens \
if batch_max_tokens > tail_max_tokens \
else tail_max_tokens
new_batch_sentences = new_batch_end - batch_start
new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens
overflow = (new_batch_sentences > max_sentences > 0 or
new_batch_num_tokens > max_tokens > 0)
size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
new_batch_sentences % bsz_mult == 0)
if overflow:
tail_num_tokens = tail_max_tokens * \
(new_batch_end - batches_ends_view[batches_count])
tail_overflow = tail_num_tokens > max_tokens > 0
# In case of a tail overflow finalize two batches
if tail_overflow:
batches_count += 1
batches_ends_view[batches_count] = pos
tail_max_tokens = num_tokens_view[pos]
batch_start = batches_ends_view[batches_count]
batches_count += 1
new_batch_max_tokens = tail_max_tokens
if overflow or size_matches_with_bsz_mult:
batches_ends_view[batches_count] = new_batch_end
batch_max_tokens = new_batch_max_tokens
tail_max_tokens = 0
if batches_ends_view[batches_count] != indices_len:
batches_count += 1
# Memory and time-efficient split
return np.split(indices, batches_ends[:batches_count])
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
cdef int32_t indices_len = indices.shape[0]
cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
dtype=np.int64)
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
cdef int64_t pos
for pos in range(indices_len):
num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
max_sentences, bsz_mult,)
cdef _find_valid_shape(
DTYPE_t[:, :] shapes_view,
int64_t num_sentences,
int64_t num_tokens,
):
"""Return index of first valid shape of -1 if none is found."""
for i in range(shapes_view.shape[0]):
if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
return i
return -1
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
for i in range(len(indices_view)):
idx = indices_view[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
if shape_idx == -1:
batches.append(batch)
batch = []
sample_lens = []
sample_len = 0
shapes_view = fixed_shapes_sorted
elif shape_idx > 0:
# small optimization for the next call to _find_valid_shape
shapes_view = shapes_view[shape_idx:]
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches

View File

@@ -0,0 +1,443 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from . import FairseqDataset, data_utils
def collate(
samples,
pad_idx,
eos_idx,
vocab,
left_pad_source=False,
left_pad_target=False,
input_feeding=True,
pad_to_length=None,
):
assert input_feeding
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
left_pad=left_pad,
move_eos_to_beginning=move_eos_to_beginning,
pad_to_length=pad_to_length,
)
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
"source",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge(
"target",
left_pad=left_pad_target,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
target = target.index_select(0, sort_order)
ntokens = sum(len(s["target"]) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
"target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s["source"]) for s in samples)
batch = {
"id": id,
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
"nsentences": samples[0]["source"].size(0),
"sort_order": sort_order,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
class DenoisingDataset(FairseqDataset):
"""
A wrapper around TokenBlockDataset for BART dataset.
Args:
dataset (TokenBlockDataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
mask_idx (int): dictionary index used for masked token
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
seed: Seed for random number generator for reproducibility.
"""
def __init__(
self,
dataset,
sizes,
vocab,
mask_idx,
mask_whole_words,
shuffle,
seed,
mask,
mask_random,
insert,
rotate,
permute_sentences,
bpe,
replace_length,
mask_length,
poisson_lambda,
eos=None,
item_transform_func=None,
):
self.dataset = dataset
self.sizes = sizes
self.vocab = vocab
self.shuffle = shuffle
self.seed = seed
self.mask_idx = mask_idx
self.mask_whole_word = mask_whole_words
self.mask_ratio = mask
self.random_ratio = mask_random
self.insert_ratio = insert
self.rotate_ratio = rotate
self.permute_sentence_ratio = permute_sentences
self.eos = eos if eos is not None else vocab.eos()
self.item_transform_func = item_transform_func
if bpe != "gpt2":
self.full_stop_index = self.vocab.eos()
else:
assert bpe == "gpt2"
self.full_stop_index = self.vocab.index("13")
self.replace_length = replace_length
if self.replace_length not in [-1, 0, 1]:
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
if mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={mask_length}")
if mask_length == "subword" and replace_length not in [0, 1]:
raise ValueError(f"if using subwords, use replace-length=1 or 0")
self.mask_span_distribution = None
if mask_length == "span-poisson":
_lambda = poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= k + 1
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
self.mask_span_distribution = torch.distributions.Categorical(ps)
self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
tokens = self.dataset[index]
assert tokens[-1] == self.eos
source, target = tokens, tokens.clone()
if self.permute_sentence_ratio > 0.0:
source = self.permute_sentences(source, self.permute_sentence_ratio)
if self.mask_ratio > 0:
source = self.add_whole_word_mask(source, self.mask_ratio)
if self.insert_ratio > 0:
source = self.add_insertion_noise(source, self.insert_ratio)
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
source = self.add_rolling_noise(source)
# there can additional changes to make:
if self.item_transform_func is not None:
source, target = self.item_transform_func(source, target)
assert (source >= 0).all()
assert (source[1:-1] >= 1).all()
assert (source <= len(self.vocab)).all()
assert source[0] == self.vocab.bos()
assert source[-1] == self.eos
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return len(self.dataset)
def permute_sentences(self, source, p=1.0):
full_stops = source == self.full_stop_index
# Pretend it ends with a full stop so last span is a sentence
full_stops[-2] = 1
# Tokens that are full stops, where the previous token is not
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
result = source.clone()
num_sentences = sentence_ends.size(0)
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
substitutions = torch.randperm(num_sentences)[:num_to_permute]
ordering = torch.arange(0, num_sentences)
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
# Ignore <bos> at start
index = 1
for i in ordering:
sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
result[index : index + sentence.size(0)] = sentence
index += sentence.size(0)
return result
def word_starts(self, source):
if self.mask_whole_word is not None:
is_word_start = self.mask_whole_word.gather(0, source)
else:
is_word_start = torch.ones(source.size())
is_word_start[0] = 0
is_word_start[-1] = 0
return is_word_start
def add_whole_word_mask(self, source, p):
is_word_start = self.word_starts(source)
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat(
[
lengths,
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
],
dim=0,
)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(source, num_inserts / source.size(0))
assert (lengths > 0).all()
else:
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
indices = word_starts[
torch.randperm(word_starts.size(0))[:num_to_mask]
].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
is_word_start[
-1
] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(
1, len(self.vocab), size=(mask_random.sum(),)
)
assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(source, num_inserts / source.size(0))
return source
def add_permuted_noise(self, tokens, p):
num_words = len(tokens)
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
return tokens
def add_rolling_noise(self, tokens):
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
tokens = torch.cat(
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
dim=0,
)
return tokens
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = len(tokens)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
noise_mask[noise_indices] = 1
result = torch.LongTensor(n + len(tokens)).fill_(-1)
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
result[noise_indices[:num_random]] = torch.randint(
low=1, high=len(self.vocab), size=(num_random,)
)
result[~noise_mask] = tokens
assert (result >= 0).all()
return result
def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch of data
"""
return collate(
samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return self.sizes[index]
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
return indices[np.argsort(self.sizes[indices], kind="mergesort")]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property
def supports_prefetch(self):
return (
hasattr(self.src, "supports_prefetch")
and self.src.supports_prefetch
and hasattr(self.tgt, "supports_prefetch")
and self.tgt.supports_prefetch
)

View File

@@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections import Counter
from multiprocessing import Pool
import torch
from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
class Dictionary:
"""A mapping from symbols to consecutive integers"""
def __init__(
self,
*, # begin keyword-only arguments
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def get_count(self, idx):
return self.count[idx]
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def __contains__(self, sym):
return sym in self.indices
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(
self,
tensor,
bpe_symbol=None,
escape_unk=False,
extra_symbols_to_ignore=None,
unk_string=None,
include_eos=False,
separator=" ",
):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return "\n".join(
self.string(
t,
bpe_symbol,
escape_unk,
extra_symbols_to_ignore,
include_eos=include_eos,
)
for t in tensor
)
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
if not include_eos:
extra_symbols_to_ignore.add(self.eos())
def token_string(i):
if i == self.unk():
if unk_string is not None:
return unk_string
else:
return self.unk_string(escape_unk)
else:
return self[i]
if hasattr(self, "bos_index"):
extra_symbols_to_ignore.add(self.bos())
sent = separator.join(
token_string(i)
for i in tensor
if utils.item(i) not in extra_symbols_to_ignore
)
return data_utils.post_process(sent, bpe_symbol)
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
if escape:
return "<{}>".format(self.unk_word)
else:
return self.unk_word
def add_symbol(self, word, n=1, overwrite=False):
"""Adds a word to the dictionary"""
if word in self.indices and not overwrite:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[: self.nspecial]
new_count = self.count[: self.nspecial]
c = Counter(
dict(
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
)
)
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
self.pad_to_multiple_(padding_factor)
def pad_to_multiple_(self, padding_factor):
"""Pad Dictionary size to be a multiple of *padding_factor*."""
if padding_factor > 1:
i = 0
while len(self) % padding_factor != 0:
symbol = "madeupword{:04d}".format(i)
self.add_symbol(symbol, n=0)
i += 1
def bos(self):
"""Helper to get index of beginning-of-sentence symbol"""
return self.bos_index
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f)
return d
def add_from_file(self, f):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str):
try:
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception(
"Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f)
)
return
lines = f.readlines()
indices_start_line = self._load_meta(lines)
for line in lines[indices_start_line:]:
try:
line, field = line.rstrip().rsplit(" ", 1)
if field == "#fairseq:overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
else:
overwrite = False
count = int(field)
word = line
if word in self and not overwrite:
raise RuntimeError(
"Duplicate word found when loading Dictionary: '{}'. "
"Duplicate words can overwrite earlier ones by adding the "
"#fairseq:overwrite flag at the end of the corresponding row "
"in the dictionary file. If using the Camembert model, please "
"download an updated copy of the model file.".format(word)
)
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
)
def _save(self, f, kv_iterator):
if isinstance(f, str):
PathManager.mkdirs(os.path.dirname(f))
with PathManager.open(f, "w", encoding="utf-8") as fd:
return self.save(fd)
for k, v in kv_iterator:
print("{} {}".format(k, v), file=f)
def _get_meta(self):
return [], []
def _load_meta(self, lines):
return 0
def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(
f,
zip(
ex_keys + self.symbols[self.nspecial :],
ex_vals + self.count[self.nspecial :],
),
)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
def encode_line(
self,
line,
line_tokenizer=tokenize_line,
add_if_not_exist=True,
consumer=None,
append_eos=True,
reverse_order=False,
) -> torch.IntTensor:
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
@staticmethod
def _add_file_to_dictionary_single_worker(
filename,
tokenize,
eos_word,
start_offset,
end_offset,
):
counter = Counter()
with Chunker(filename, start_offset, end_offset) as line_iterator:
for line in line_iterator:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
return counter
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in sorted(counter.items()):
dict.add_symbol(w, c)
local_file = PathManager.get_local_path(filename)
offsets = find_offsets(local_file, num_workers)
if num_workers > 1:
chunks = zip(offsets, offsets[1:])
pool = Pool(processes=num_workers)
results = []
for (start_offset, end_offset) in chunks:
results.append(
pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(
local_file,
tokenize,
dict.eos_word,
start_offset,
end_offset,
),
)
)
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(
Dictionary._add_file_to_dictionary_single_worker(
local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
)
)
class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length):
self.__class__ = type(
wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__),
{},
)
self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length)
def __len__(self):
return self.length
def __getitem__(self, i):
if i < self.length:
return self.wrapped_dict[i]
return self.wrapped_dict.unk()

View File

@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
"--tokenizer",
default=None,
)
build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
"--bpe",
default=None,
)
# automatically import any Python files in the encoders/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("fairseq.data.encoders." + module)

View File

@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass
@dataclass
class ByteBpeConfig(FairseqDataclass):
sentencepiece_model_path: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
def __init__(self, cfg):
vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
try:
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.Load(vocab)
except ImportError:
raise ImportError(
"Please install sentencepiece with: pip install sentencepiece"
)
def encode(self, x: str) -> str:
byte_encoded = byte_encode(x)
return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))
@staticmethod
def decode(x: str) -> str:
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)

View File

@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
BYTE_TO_BCHAR = {
b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
def byte_encode(x: str) -> str:
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
def byte_decode(x: str) -> str:
try:
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
except ValueError:
return ""
def smart_byte_decode(x: str) -> str:
output = byte_decode(x)
if output == "":
# DP the best recovery (max valid chars) if it's broken
n_bytes = len(x)
f = [0 for _ in range(n_bytes + 1)]
pt = [0 for _ in range(n_bytes + 1)]
for i in range(1, n_bytes + 1):
f[i], pt[i] = f[i - 1], i - 1
for j in range(1, min(4, i) + 1):
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
f[i], pt[i] = f[i - j] + 1, i - j
cur_pt = n_bytes
while cur_pt > 0:
if f[cur_pt] == f[pt[cur_pt]] + 1:
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
cur_pt = pt[cur_pt]
return output

View File

@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
@register_bpe("bytes")
class Bytes(object):
def __init__(self, *unused):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
encoded = byte_encode(x)
escaped = encoded.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
return smart_byte_decode(unescaped)

View File

@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
@register_bpe("characters")
class Characters(object):
def __init__(self, *unused):
pass
@staticmethod
def add_args(parser):
pass
@staticmethod
def encode(x: str) -> str:
escaped = x.replace(SPACE, SPACE_ESCAPE)
return SPACE.join(list(escaped))
@staticmethod
def decode(x: str) -> str:
return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)

View File

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class fastBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
def __init__(self, cfg):
if cfg.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(cfg.bpe_codes)
try:
import fastBPE
self.bpe = fastBPE.fastBPE(codes)
self.bpe_symbol = "@@ "
except ImportError:
raise ImportError("Please install fastBPE with: pip install fastBPE")
def encode(self, x: str) -> str:
return self.bpe.apply([x])[0]
def decode(self, x: str) -> str:
return (x + " ").replace(self.bpe_symbol, "").rstrip()

View File

@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@dataclass
class GPT2BPEConfig(FairseqDataclass):
gpt2_encoder_json: str = field(
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
)
gpt2_vocab_bpe: str = field(
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
)
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
def __init__(self, cfg):
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
return self.bpe.decode(
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
)
def is_beginning_of_word(self, x: str) -> bool:
return self.decode(x).startswith(" ")

View File

@@ -0,0 +1,140 @@
"""
Byte pair encoding utilities from GPT-2.
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""
import json
from functools import lru_cache
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
try:
import regex as re
self.re = re
except ImportError:
raise ImportError("Please install regex with: pip install regex")
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = self.re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in self.re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder.get(token, token) for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
"utf-8", errors=self.errors
)
return text
def get_encoder(encoder_json_path, vocab_bpe_path):
with open(encoder_json_path, "r") as f:
encoder = json.load(f)
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)

Some files were not shown because too many files have changed in this diff Show More