Merge pull request #116 from Tony-sama/neo

Feature: RVC module, applying voice conversion to audio sent by ST
This commit is contained in:
Cohee
2023-08-10 19:46:15 +03:00
committed by GitHub
404 changed files with 90425 additions and 2 deletions

View File

0
data/tmp/.placeholder Normal file
View File

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import os
import sys
# Resolve the package version: use the ``version`` module when it can be
# imported, otherwise read ``version.txt`` from the directory of this file.
try:
    from .version import __version__  # noqa
except ImportError:
    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
    with open(version_txt) as f:
        __version__ = f.read().strip()
__all__ = ["pdb"]
# backwards compatibility to support `from fairseq.X import Y`
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import meters, metrics, progress_bar  # noqa
# Register the modules imported above under flat ``fairseq.<name>`` keys in
# sys.modules.
sys.modules["fairseq.distributed_utils"] = distributed_utils
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar
# initialize hydra
#from fairseq.dataclass.initialize import hydra_init
#hydra_init()
#import fairseq.criterions # noqa
#import fairseq.distributed # noqa
#import fairseq.models # noqa
#import fairseq.modules # noqa
#import fairseq.optim # noqa
#import fairseq.optim.lr_scheduler # noqa
#import fairseq.pdb # noqa
#import fairseq.scoring # noqa
#import fairseq.tasks # noqa
#import fairseq.token_generation_constraints # noqa
#import fairseq.benchmark # noqa
#import fairseq.model_parallel # noqa

View File

@@ -0,0 +1,381 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Pool
import torch
from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
logger = logging.getLogger("binarizer")
@dataclass
class BinarizeSummary:
    """
    Keep track of what's going on in the binarizer

    The fields are running totals; ``merge`` folds another summary into
    this one.
    """

    # number of sequences counted so far
    num_seq: int = 0
    # counts of replaced items; stays None until something starts counting
    replaced: tp.Optional[Counter] = None
    # number of tokens counted so far
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        """Total of all counts in ``replaced`` (0 when it is None)."""
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        """Replaced items as a percentage of ``num_tok`` (0.0 when no tokens)."""
        if self.num_tok == 0:
            # nothing was binarized yet; avoid a ZeroDivisionError
            return 0.0
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base

        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary"):
        """Add the totals of ``other`` to this summary, leaving ``other`` untouched."""
        if other.replaced is not None:
            if self.replaced is None:
                # copy instead of aliasing, so that later in-place merges
                # into this summary never modify other's Counter
                self.replaced = Counter(other.replaced)
            else:
                self.replaced += other.replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok
class Binarizer(ABC):
    """
    Abstract interface: a binarizer turns one line of text into a tensor.
    """

    @abstractmethod
    def binarize_line(self, line: str, summary: BinarizeSummary) -> torch.IntTensor:
        """Build the tensor for ``line``; subclasses receive ``summary`` to update."""
def _worker_prefix(output_prefix: str, worker_id: int):
return f"{output_prefix}.pt{worker_id}"
class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        """
        Binarize ``input_file`` into the dataset files named after ``output_prefix``.

        The file is split into chunks with ``find_offsets``. The first chunk is
        binarized in the calling process; every other chunk is handed to a pool
        worker that writes its own files under ``_worker_prefix(output_prefix,
        worker_id)``. The worker outputs are then merged into the main builder
        and deleted, and the builder is finalized.

        Returns the summary merged over all chunks.
        """
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
        # we zip the list with itself shifted by one to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        if num_workers > 1:
            # only forward vocab_size when the caller actually provided one
            worker_kwargs = {"vocab_size": vocab_size} if vocab_size is not None else {}
            # the context manager terminates the pool even if submitting or
            # joining raises, so no worker processes are left behind
            with Pool(processes=num_workers - 1) as pool:
                worker_results = [
                    pool.apply_async(
                        cls._binarize_chunk_and_finalize,
                        args=(
                            binarizer,
                            input_file,
                            start_offset,
                            end_offset,
                            _worker_prefix(
                                output_prefix,
                                worker_id,
                            ),
                            dataset_impl,
                        ),
                        kwds=worker_kwargs,
                    )
                    for worker_id, (start_offset, end_offset) in enumerate(
                        more_chunks, start=1
                    )
                ]

                pool.close()
                pool.join()
            for r in worker_results:
                # get() re-raises any exception that happened in the worker
                final_summary.merge(r.get())

        # do not close the bin file as we need to merge the worker results in
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    # best effort: a leftover temporary file must not fail the run
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary

    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        creates a dataset builder and append binarized items to it. This function does not
        finalize the builder, this is useful if you want to do other things with your bin file
        like appending/merging other files
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary

    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        same as above, but also finalizes the builder
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ
class VocabularyDatasetBinarizer(Binarizer):
    """
    Binarizer backed by a Dictionary/Vocabulary: token ids come from the
    dictionary's encode_line function, or are parsed directly from the line
    when the input is already numberized.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        super().__init__()
        self.already_numberized = already_numberized
        self.reverse_order = reverse_order
        self.append_eos = append_eos
        self.tokenize = tokenize
        self.dict = dict

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        """Return the ids for ``line`` and add one sequence plus its tokens to ``summary``."""
        # start counting replacements the first time this summary is used
        if summary.replaced is None:
            summary.replaced = Counter()

        def _record_replacement(word, idx):
            # a word that is not the unk symbol itself but mapped to unk
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if not self.already_numberized:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=_record_replacement,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )
        else:
            # the line already holds whitespace-separated integer ids
            id_list = [int(piece) for piece in line.strip().split()]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
class AlignmentDatasetBinarizer(Binarizer):
    """
    Binarizer that delegates each line to an alignment parser, which packs
    the parsed alignments in a tensor (see utils.parse_alignment)
    """

    def __init__(
        self,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        """Parse ``line`` and add one sequence plus its length to ``summary``."""
        parsed = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(parsed)
        return parsed
class LegacyBinarizer:
    """Entry points that binarize a file chunk and pass every tensor to a consumer callback."""

    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, int]:
        """Binarize ``filename`` between ``offset`` and ``end`` with a vocabulary binarizer."""
        vocab_binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        stats = cls._consume_file(
            filename,
            vocab_binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )
        return stats

    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, int]:
        """Binarize ``filename`` between ``offset`` and ``end`` with an alignment binarizer."""
        stats = cls._consume_file(
            filename,
            AlignmentDatasetBinarizer(alignment_parser),
            consumer,
            offset_start=offset,
            offset_end=end,
        )
        return stats

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        """Feed every binarized line of the chunk to ``consumer`` and return the counts."""
        totals = BinarizeSummary()
        local_path = PathManager.get_local_path(filename)

        with Chunker(local_path, offset_start, offset_end) as chunk_lines:
            for text in chunk_lines:
                tensor = binarizer.binarize_line(text, totals)
                consumer(tensor)

        result = {}
        result["nseq"] = totals.num_seq
        result["nunk"] = totals.num_replaced
        result["ntok"] = totals.num_tok
        result["replaced"] = totals.replaced
        return result

View File

@@ -0,0 +1,905 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict
logger = logging.getLogger(__name__)
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    """
    Write the checkpoint files selected by ``cfg`` and prune old ones.

    The best validation score seen so far is kept on the function itself as
    ``save_checkpoint.best``. One checkpoint is written through
    ``trainer.save_checkpoint`` and then copied to every other filename whose
    condition holds. Afterwards, old interval, epoch and best checkpoints are
    removed according to the ``keep_*`` settings in ``cfg``.

    Nothing is written when ``cfg.no_save`` is set or when the trainer says
    the current rank should not save.
    """
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    def _remove_checkpoints(paths):
        # delete local files directly; fall back to PathManager for the rest
        for old_chk in paths:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    suffix = trainer.checkpoint_suffix
    # maps each candidate filename to whether it should be written this time
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                # run the copy outside the assert: asserts are stripped
                # under ``python -O``, which would silently skip the copy
                copied = PathManager.copy(checkpoints[0], cp, overwrite=True)
                assert copied, f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            # checkpoints whose update count is a multiple of the pattern
            # are excluded from the removal candidates
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        _remove_checkpoints(checkpoints[cfg.keep_interval_updates :])

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        _remove_checkpoints(checkpoints[cfg.keep_last_epochs :])

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        _remove_checkpoints(checkpoints[cfg.keep_best_checkpoints :])
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Restore a checkpoint through the trainer and rebuild the training iterator.

    Every keyword in *passthrough_args* is forwarded unchanged to
    ``trainer.get_train_iterator``. Returns ``(extra_state, epoch_itr)``.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    any_reset = (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    )
    if cfg.finetune_from_model is not None and any_reset:
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    # 'checkpoint_last.pt' is the default value of restore_file
    if cfg.restore_file == "checkpoint_last.pt":
        last_name = "checkpoint_last{}.pt".format(suffix)
        checkpoint_path = os.path.join(cfg.save_dir, last_name)
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # with no last checkpoint available, finetuning starts from the
            # pretrained model; otherwise the usual restart logic applies
            if not PathManager.exists(cfg.finetune_from_model):
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
            checkpoint_path = cfg.finetune_from_model
            reset_optimizer = True
            reset_lr_scheduler = True
            reset_meters = True
            reset_dataloader = True
            logger.info(
                f"loading pretrained model from {checkpoint_path}: "
                "optimizer, lr scheduler, meters, dataloader will be reset"
            )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    # carry the stored best score over unless optimizer or meters were reset
    if (
        extra_state is not None
        and "best" in extra_state
        and not (reset_optimizer or reset_meters)
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is None or reset_dataloader:
        # nothing to resume from: start a fresh iterator at epoch 1
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )
    else:
        # resume the iterator from the state stored in the checkpoint
        saved_itr = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=saved_itr["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(saved_itr)

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.

    Args:
        path: checkpoint location, resolved through ``PathManager``
        arg_overrides: optional mapping of names to values applied on top of
            the ``args`` and/or ``cfg`` stored in the checkpoint
        load_on_all_ranks: see above

    Returns:
        the loaded state dict after ``_upgrade_state_dict``
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    # NOTE(review): torch.load may unpickle arbitrary objects; only load
    # checkpoints that come from a trusted source.
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        # compare (major, minor) numerically: a plain string comparison would
        # order e.g. "2.10" before "2.2"
        version_match = re.match(r"(\d+)\.(\d+)", oc_version)
        uses_legacy_omegaconf = version_match is not None and (
            int(version_match.group(1)),
            int(version_match.group(2)),
        ) < (2, 2)

        if uses_legacy_omegaconf:
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True
            try:
                state["cfg"] = OmegaConf.create(state["cfg"])
            finally:
                # always undo the monkeypatch, even if create() raises
                _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load several checkpoints as a list of models.

    Thin wrapper around ``load_model_ensemble_and_task`` that returns only
    the first two elements of its result.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert (not strict) or not (
        num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    models, model_args, _unused_task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return models, model_args
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    """Pick the checkpoint filename to read for shard ``shard_idx``.

    Preference order: the ``-shard<idx>`` file (built from the suffixed name)
    when it exists, else the ``_part<idx>`` file (built from the original
    name) when there is more than one shard, else the suffixed name itself.
    """
    suffixed = filename.replace(".pt", suffix + ".pt")
    fsdp_candidate = "{}-shard{}.pt".format(suffixed[:-3], shard_idx)
    if PathManager.exists(fsdp_candidate):
        return fsdp_candidate
    if num_shards > 1:
        return "{}_part{}.pt".format(filename[:-3], shard_idx)
    return suffixed
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load one model per entry of ``filenames`` and the task they belong to.

    For every filename, each of the ``num_shards`` shard files is located
    with ``get_maybe_sharded_checkpoint_filename`` and loaded. The config is
    taken from the ``args`` or ``cfg`` entry of the checkpoint, the task is
    set up from it when none was passed in, and a model is built and filled
    from the checkpoint weights.

    Args:
        filenames: checkpoint files to load
        arg_overrides: overrides forwarded to ``load_checkpoint_to_cpu``
        task: task to use; set up from the first checkpoint when None
        strict: forwarded to ``model.load_state_dict``
        suffix: filename suffix used when locating shard files
        num_shards: number of shard files per checkpoint
        state: an already-loaded state for the (single) filename

    Returns:
        ``(ensemble, cfg, task)``

    Raises:
        IOError: a checkpoint file does not exist
        RuntimeError: a checkpoint has neither ``args`` nor ``cfg``
        ImportError: sharded FSDP weights are present but FSDP is unavailable
    """
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"

    def _restore_num_updates(model, state):
        # carry the update counter of the last optimizer history entry over
        if (
            "optimizer_history" in state
            and len(state["optimizer_history"]) > 0
            and "num_updates" in state["optimizer_history"][-1]
        ):
            model.set_num_updates(state["optimizer_history"][-1]["num_updates"])

    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                # the model can only be built once every shard was collected
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    _restore_num_updates(model, state)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                _restore_num_updates(model, state)
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                # shard_idx is zero-based, so shard_idx + 1 shards are loaded;
                # use the same count for the total and for the rate
                loaded = shard_idx + 1
                logger.info(
                    f"Loaded {loaded} shards in {elapsed:.2f}s, {elapsed / loaded:.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    """Download ``model_id`` with ``huggingface_hub`` and load its ``*.pt`` files.

    Args:
        model_id: identifier passed to ``snapshot_download``
        cache_dir: download cache; defaults to ``~/.cache/fairseq``
        arg_overrides: overrides for the stored model args; the ``data``
            entry is always set to the downloaded snapshot directory. The
            caller's mapping is not modified.
        **kwargs: forwarded to ``snapshot_download``

    Returns:
        the result of ``load_model_ensemble_and_task``

    Raises:
        ImportError: ``huggingface_hub`` is not installed
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        ) from err

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    # work on a copy so the "data" entry never leaks into the caller's dict
    _arg_overrides = dict(arg_overrides) if arg_overrides else {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """List the checkpoints inside the `path` directory.

    A file counts as a checkpoint when its whole name matches `pattern`. The
    sort key is the first regex group converted to float, or the listing
    position when the pattern has no groups; results are in descending key
    order. With `keep_match` each result is a `(full_path, key)` tuple,
    otherwise just the full path.
    """
    compiled = re.compile(pattern)

    matched = []
    for position, name in enumerate(PathManager.ls(path)):
        hit = compiled.fullmatch(name)
        if hit is None:
            continue
        key = float(hit.group(1)) if len(hit.groups()) > 0 else position
        matched.append((key, hit.group(0)))
    matched.sort(reverse=True)

    if keep_match:
        return [(os.path.join(path, name), key) for key, name in matched]
    return [os.path.join(path, name) for _, name in matched]
def torch_persistent_save(obj, filename, async_write: bool = False):
    """Serialize ``obj`` to ``filename`` through ``PathManager``.

    With ``async_write`` the file is opened via ``PathManager.opena``.
    Otherwise the data goes to ``filename + ".tmp"`` first and is renamed
    into place when the path supports renaming, or is written directly when
    it does not.
    """
    if async_write:
        with PathManager.opena(filename, "wb") as handle:
            _torch_persistent_save(obj, handle)
        return

    if not PathManager.supports_rename(filename):
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as handle:
            _torch_persistent_save(obj, handle)
        return

    # do atomic save
    tmp_name = filename + ".tmp"
    with PathManager.open(tmp_name, "wb") as handle:
        _torch_persistent_save(obj, handle)
    PathManager.rename(tmp_name, filename)
def _torch_persistent_save(obj, f):
if isinstance(f, str):
with PathManager.open(f, "wb") as h:
torch_persistent_save(obj, h)
return
for i in range(3):
try:
return torch.save(obj, f)
except Exception:
if i == 2:
logger.error(traceback.format_exc())
raise
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Applies a fixed sequence of in-place migrations to ``state`` (missing
    keys are filled in, old keys are moved or renamed) and returns the same
    dict. The steps build on each other, so their order matters.
    """
    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        # rebuild cfg from the upgraded args; this replaces any cfg that was
        # already stored in the checkpoint
        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            # a boolean print_alignment in the task's eval_wer_config becomes "hard"
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            # a boolean generation.print_alignment becomes "hard" (True) or None (False)
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            # same boolean -> "hard" conversion for the config nested under
            # cfg.model.w2v_args
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Drop and renumber layers of ``state_dict`` for LayerDrop pruning.

    LayerDrop (https://arxiv.org/abs/1909.11556) trains models that tolerate
    layer removal at inference time. When ``model_cfg`` carries
    ``encoder_layers_to_keep`` and/or ``decoder_layers_to_keep`` (comma
    separated layer indices), only those layers are kept and they are
    renumbered consecutively from 0 so that a smaller model can load them.

    This is invoked by the checkpoint-loading helpers and normally does not
    need to be called directly.

    Args:
        state_dict: mapping from parameter name to value.
        model_cfg: model configuration (a ``DictConfig`` or an object with
            attributes); may be ``None``.

    Returns:
        ``state_dict`` itself when no pruning applies, otherwise a new dict
        with pruned and renumbered keys.
    """
    arch = None
    if model_cfg is not None:
        if isinstance(model_cfg, DictConfig):
            arch = model_cfg._name
        else:
            arch = getattr(model_cfg, "arch", None)

    # The config should never be missing, but do not crash if it is.
    if not model_cfg or arch is None or arch == "ptt_transformer":
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
    if not (encoder_layers_to_keep or decoder_layers_to_keep):
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        # Old layer index (as a string) -> new consecutive index (as a string).
        kept = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
        mapping_dict = {str(old): str(new) for new, old in enumerate(kept)}
        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    for layers_to_keep, component in (
        (encoder_layers_to_keep, "encoder"),
        (decoder_layers_to_keep, "decoder"),
    ):
        if layers_to_keep:
            pruning_passes.append(create_pruning_pass(layers_to_keep, component))

    layer_index_re = re.compile(r"\.layers\.(\d+)\.")
    new_state_dict = {}
    for layer_name in state_dict.keys():
        found = layer_index_re.search(layer_name)
        if not found:
            # No layer number in the key: a supporting parameter such as an
            # embedding, copied through unchanged.
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # Numbered layer: keep it only if some pass maps its index.
        original_layer_number = found.group(1)
        for pruning_pass in pruning_passes:
            if original_layer_number not in pruning_pass["mapping_dict"]:
                continue
            substitution_match = pruning_pass["substitution_regex"].search(layer_name)
            if not substitution_match:
                continue
            new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
            prefix = layer_name[: substitution_match.start(1)]
            suffix = layer_name[substitution_match.end(1) :]
            new_state_dict[prefix + new_layer_number + suffix] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    context = (
        open_dict(model_cfg)
        if isinstance(model_cfg, DictConfig)
        else contextlib.ExitStack()
    )
    with context:
        for attr in ("encoder_layers_to_keep", "decoder_layers_to_keep"):
            if hasattr(model_cfg, attr):
                setattr(model_cfg, attr, None)

    return new_state_dict
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """Populate ``component`` with the matching weights stored in ``checkpoint``.

    Keys of the checkpoint's ``"model"`` entry that start with ``encoder`` or
    ``decoder`` (chosen from the type of ``component``) are selected, the
    prefix plus one separator character is removed, and the result is passed
    to ``component.load_state_dict``. A loading failure may indicate an
    architecture mismatch between ``component`` and the checkpoint.

    Args:
        component: the FairseqEncoder or FairseqDecoder to fill.
        checkpoint: path of the checkpoint file.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The same ``component`` object.

    Raises:
        IOError: if ``checkpoint`` does not exist.
        ValueError: if ``component`` is neither an encoder nor a decoder.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)

    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )

    # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
    strip = len(component_type) + 1
    component_state_dict = OrderedDict(
        (key[strip:], state["model"][key])
        for key in state["model"].keys()
        if key.startswith(component_type)
    )
    component.load_state_dict(component_state_dict, strict=strict)
    return component
def verify_checkpoint_directory(save_dir: str) -> None:
    """Make sure ``save_dir`` exists and can be written to.

    The directory is created when missing, then a probe file named
    ``dummy`` is created inside it and removed again.

    Raises:
        OSError: if the probe file cannot be created; a warning is logged
            before the error is re-raised.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    probe_path = os.path.join(save_dir, "dummy")
    try:
        with open(probe_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    # Only reached when the probe was written successfully.
    os.remove(probe_path)
def save_ema_as_checkpoint(src_path, dst_path):
    """Load the EMA state from ``src_path`` and save it to ``dst_path``."""
    torch_persistent_save(load_ema_from_checkpoint(src_path), dst_path)
def load_ema_from_checkpoint(fpath):
    """Load a checkpoint and expose its EMA weights as the model weights.

    Args:
        fpath: A string path of checkpoint to load from.

    Returns:
        The loaded checkpoint dict, with its ``'model'`` entry replaced by an
        OrderedDict mapping parameter names to the EMA tensors found under
        ``state["extra_state"]["ema"]`` (half-precision tensors are converted
        to float).

    Raises:
        ValueError: if the checkpoint holds no EMA parameters.
    """

    def _restore_on_cpu(storage, _location):
        return torch.serialization.default_restore_location(storage, "cpu")

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(f, map_location=_restore_on_cpu)

    # EMA model is stored in a separate "extra state"
    model_params = new_state["extra_state"]["ema"]

    params_dict = collections.OrderedDict()
    for key in list(model_params.keys()):
        if key in params_dict:
            raise ValueError("Key {} is repeated in EMA model params.".format(key))
        p = model_params[key]
        if isinstance(p, torch.HalfTensor):
            p = p.float()
        # NOTE: clone() is needed in case of p is a shared parameter
        params_dict[key] = p.clone()

    if len(params_dict) == 0:
        raise ValueError(
            f"Input checkpoint path '{fpath}' does not contain "
            "ema model weights, is this model trained with EMA?"
        )

    new_state["model"] = params_dict
    return new_state

View File

@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
"AddTargetDataset",
"AppendTokenDataset",
"BacktranslationDataset",
"BaseWrapperDataset",
"BinarizedAudioDataset",
"BucketPadLengthDataset",
"ColorizeDataset",
"ConcatDataset",
"ConcatSentencesDataset",
"CountingIterator",
"DenoisingDataset",
"Dictionary",
"EncodedFastaDataset",
"EpochBatchIterator",
"FairseqDataset",
"FairseqIterableDataset",
"FastaDataset",
"FileAudioDataset",
"GroupedIterator",
"HubertDataset",
"IdDataset",
"IndexedCachedDataset",
"IndexedDataset",
"IndexedRawTextDataset",
"LanguagePairDataset",
"LeftPadDataset",
"ListDataset",
"LMContextWindowDataset",
"LRUCacheDataset",
"MaskTokensDataset",
"MMapIndexedDataset",
"MonolingualDataset",
"MultiCorpusSampledDataset",
"NestedDictionaryDataset",
"NoisingDataset",
"NumelDataset",
"NumSamplesDataset",
"OffsetTokensDataset",
"PadDataset",
"PrependDataset",
"PrependTokenDataset",
"RandomCropDataset",
"RawLabelDataset",
"ResamplingDataset",
"ReplaceDataset",
"RightPadDataset",
"RollDataset",
"RoundRobinZipDatasets",
"SampledMultiDataset",
"SampledMultiEpochDataset",
"ShardedIterator",
"SortDataset",
"StripTokenDataset",
"SubsampleDataset",
"TokenBlockDataset",
"TransformEosDataset",
"TransformEosLangPairDataset",
"TransformEosConcatLangPairDataset",
"TruncateDataset",
"TruncatedDictionary",
]

View File

@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
class AddTargetDataset(BaseWrapperDataset):
    """Wrap a dataset and attach a label (target) to every sample.

    Args:
        dataset: the wrapped dataset; its items are dicts and its collater
            returns a dict containing ``"id"`` and ``"net_input"``.
        labels: indexable collection of (possibly compressed) labels.
        pad: padding index used when batching targets.
        eos: end-of-sentence index used when ``add_to_input`` is set.
        batch_targets: pad the targets into a single batch when true,
            otherwise keep them as a list.
        process_label: optional callable applied to a label in
            ``__getitem__``.
        label_len_fn: callable returning the length of a raw label; used by
            ``size``.
        add_to_input: also feed the target, shifted by EOS, to the model as
            ``net_input["prev_output_tokens"]``.
        text_compression_level: compression level the labels were stored
            with.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        """Return the decompressed label at ``index``, optionally processed."""
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        """Return ``(source size, label length)`` for the sample at ``index``."""
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        """Collate ``samples`` with the wrapped dataset and add the targets.

        Only samples whose id survived the wrapped collater contribute a
        target. The result gains ``"target"`` and ``"ntokens"``, plus
        ``"target_lengths"`` when ``batch_targets`` is set.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        prev_output_tokens = None
        if self.add_to_input:
            eos = torch.LongTensor([self.eos])
            # Decoder input is the target shifted right by one EOS token.
            prev_output_tokens = [torch.cat([eos, t], dim=-1) for t in target]
            target = [torch.cat([t, eos], dim=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            # The previous check used getattr() on the net_input dict, which
            # never finds a dict key, so the decoder inputs were left as an
            # unpadded list. Pad the list built above instead.
            if prev_output_tokens:
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    prev_output_tokens,
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        """Split ``indices`` into kept and ignored ones according to ``max_sizes``."""
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored

View File

@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class AppendTokenDataset(BaseWrapperDataset):
    """Wrap a dataset and append ``token`` to every item.

    With ``token=None`` items and sizes are passed through unchanged.
    """

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        # One extra position per item once a token is appended.
        self._sizes = dataset.sizes if token is None else np.array(dataset.sizes) + 1

    def _plus_token(self, n):
        # Account for the appended token, if any.
        if self.token is not None:
            n += 1
        return n

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is None:
            return item
        return torch.cat([item, item.new([self.token])])

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        return self._plus_token(self.dataset.num_tokens(index))

    def size(self, index):
        return self._plus_token(self.dataset.size(index))

View File

@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
import importlib
import os
import numpy as np
class AudioTransform(ABC):
    """Abstract base for audio transforms that are built from a config dict."""

    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        """Build a transform from ``config``; concrete subclasses must override."""
class CompositeAudioTransform(AudioTransform):
    """Apply a list of transforms to an input, one after another."""

    # NOTE(review): declared without @classmethod, so ``cls`` is an ordinary
    # first parameter here; presumably subclasses forward their own ``cls``
    # explicitly. Confirm against callers before adding a decorator.
    def _from_config_dict(
        cls,
        transform_type,
        get_audio_transform,
        composite_cls,
        config=None,
        return_empty=False,
    ):
        """Build ``composite_cls`` from the ``<transform_type>_transforms`` list.

        Each listed name is resolved with ``get_audio_transform`` and built
        from the config entry stored under that same name. When the list is
        absent, an empty composite is returned if ``return_empty`` is set,
        otherwise ``None``.
        """
        settings = config if config is not None else {}
        names = settings.get(f"{transform_type}_transforms")
        if names is None:
            if not return_empty:
                return None
            names = []

        built = []
        for name in names:
            transform_cls = get_audio_transform(name)
            built.append(transform_cls.from_config_dict(settings.get(name)))
        return composite_cls(built)

    def __init__(self, transforms):
        # Entries that are None are dropped.
        self.transforms = list(filter(lambda t: t is not None, transforms))

    def __call__(self, x):
        out = x
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        lines.extend(f" {t.__repr__()}" for t in self.transforms)
        lines.append(")")
        return "\n".join(lines)
def register_audio_transform(name, cls_type, registry, class_names):
    """Return a class decorator that records a transform under ``name``.

    Args:
        name: key under which the decorated class is stored in ``registry``.
        cls_type: base class the decorated class must extend.
        registry: dict mapping names to transform classes; updated in place.
        class_names: set of already registered class names; updated in place.

    The decorator raises ``ValueError`` for a duplicate ``name``, a class that
    does not extend ``cls_type``, or a duplicate class name, and otherwise
    returns the class unchanged.
    """

    def register_audio_transform_cls(cls):
        if name in registry:
            raise ValueError(f"Cannot register duplicate transform ({name})")

        if not issubclass(cls, cls_type):
            raise ValueError(
                f"Transform ({name}: {cls.__name__}) must extend "
                f"{cls_type.__name__}"
            )

        cls_name = cls.__name__
        if cls_name in class_names:
            raise ValueError(
                f"Cannot register audio transform with duplicate "
                f"class name ({cls.__name__})"
            )

        class_names.add(cls_name)
        registry[name] = cls
        return cls

    return register_audio_transform_cls
def import_transforms(transforms_dir, transform_type):
    """Import every transform module found in ``transforms_dir``.

    Entries starting with ``_`` or ``.`` are skipped; ``.py`` files and
    sub-directories are imported as
    ``fairseq.data.audio.<transform_type>_transforms.<name>``.
    """
    for file in os.listdir(transforms_dir):
        if file.startswith("_") or file.startswith("."):
            continue

        is_py_file = file.endswith(".py")
        if not is_py_file and not os.path.isdir(os.path.join(transforms_dir, file)):
            continue

        # Module name: the file name up to its first ".py", or the directory name.
        name = file[: file.find(".py")] if is_py_file else file
        importlib.import_module(
            f"fairseq.data.audio.{transform_type}_transforms." + name
        )
# Utility fn for uniform numbers in transforms
def rand_uniform(a, b):
    """Scale one ``np.random.uniform()`` draw from [0, 1) onto the span from ``a`` to ``b``."""
    span = b - a
    return a + span * np.random.uniform()

View File

@@ -0,0 +1,389 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
from pathlib import Path
import io
from typing import BinaryIO, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """convert a waveform:
    - to a target sample rate
    - from multi-channel to mono channel
    - volume normalization

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate
    Returns:
        waveform (numpy.ndarray or torch.Tensor): converted 2D waveform
            (channels x length), of the same kind as the input; the input
            object itself when no conversion is needed
        sample_rate (int): sample rate of the returned waveform
    Raises:
        ImportError: if a conversion is needed and torchaudio is missing.
    """
    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])

    if not effects:
        # Nothing to do, so torchaudio is not required for this call.
        return waveform, sample_rate

    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError as err:
        raise ImportError("Please install torchaudio: pip install torchaudio") from err

    is_np_input = isinstance(waveform, np.ndarray)
    _waveform = torch.from_numpy(waveform) if is_np_input else waveform
    converted, converted_sample_rate = ta_sox.apply_effects_tensor(
        _waveform, sample_rate, effects
    )
    if is_np_input:
        converted = converted.numpy()
    return converted, converted_sample_rate
def get_waveform(
    path_or_fp: Union[str, BinaryIO],
    normalization: bool = True,
    mono: bool = True,
    frames: int = -1,
    start: int = 0,
    always_2d: bool = True,
    output_sample_rate: Optional[int] = None,
    normalize_volume: bool = False,
    waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): normalize values to [-1, 1] (Default: True)
        mono (bool): convert multi-channel audio to mono-channel one
        frames (int): the number of frames to read. (-1 for reading all)
        start (int): Where to start reading. A negative value counts from the end.
        always_2d (bool): always return 2D array even for mono-channel audios
        output_sample_rate (Optional[int]): output sample rate
        normalize_volume (bool): normalize volume
        waveform_transforms: optional callable applied as
            ``waveform_transforms(waveform, sample_rate)`` after conversion
    Returns:
        waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
        sample_rate (float): sample rate
    Raises:
        ValueError: if a string path has an unsupported extension.
        ImportError: if soundfile is not installed.
    """
    # Only string paths carry an extension that can be validated.
    if isinstance(path_or_fp, str):
        ext = Path(path_or_fp).suffix
        if ext not in SF_AUDIO_FILE_EXTENSIONS:
            raise ValueError(f"Unsupported audio format: {ext}")

    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile: pip install soundfile")

    frames_by_channel, sample_rate = sf.read(
        path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
    )
    # T x C -> C x T
    waveform, sample_rate = convert_waveform(
        frames_by_channel.T,
        sample_rate,
        normalize_volume=normalize_volume,
        to_mono=mono,
        to_sample_rate=output_sample_rate,
    )

    if not normalization:
        # denormalized to 16-bit signed integers
        waveform *= 2**15

    if waveform_transforms is not None:
        waveform, sample_rate = waveform_transforms(waveform, sample_rate)

    if always_2d:
        return waveform, sample_rate
    return waveform.squeeze(axis=0), sample_rate
def get_features_from_npy_or_audio(path, waveform_transforms=None):
    """Load features from a ``.npy`` file, or compute fbank features from audio.

    Raises:
        ValueError: if the extension of ``path`` is not supported.
    """
    ext = Path(path).suffix
    if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        raise ValueError(f'Unsupported file format for "{path}"')
    if ext == ".npy":
        return np.load(path)
    return get_fbank(path, waveform_transforms=waveform_transforms)
def get_features_or_waveform_from_stored_zip(
    path,
    byte_offset,
    byte_size,
    need_waveform=False,
    use_sample_rate=None,
    waveform_transforms=None,
):
    """Read one entry of a ZIP file by byte range and decode it.

    The bytes are treated as ``.npy`` data or as audio depending on their
    leading bytes. For audio, the waveform is returned when
    ``need_waveform`` is set, otherwise filterbank features.

    Raises:
        ValueError: if the bytes are neither ``.npy`` nor supported audio.
    """
    assert path.endswith(".zip")
    data = read_from_stored_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)

    if is_npy_data(data):
        return np.load(f)

    if not is_sf_audio_data(data):
        raise ValueError(f'Unknown file format for "{path}"')

    if need_waveform:
        waveform, _ = get_waveform(
            f,
            always_2d=False,
            output_sample_rate=use_sample_rate,
            waveform_transforms=waveform_transforms,
        )
        return waveform
    return get_fbank(f, waveform_transforms=waveform_transforms)
def get_features_or_waveform(
    path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
):
    """Get speech features from .npy file or waveform from .wav/.flac file.
    The file may be inside an uncompressed ZIP file and is accessed via byte
    offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
        "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.
        use_sample_rate (int): change sample rate for the input wave file
        waveform_transforms: optional transforms forwarded to the loaders.

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.

    Raises:
        ValueError: if the slicing information in ``path`` is malformed.
    """
    _path, slice_ptr = parse_path(path)
    n_slice_fields = len(slice_ptr)

    if n_slice_fields == 2:
        # Entry inside a stored ZIP, addressed by byte offset and length.
        byte_offset, byte_size = slice_ptr[0], slice_ptr[1]
        return get_features_or_waveform_from_stored_zip(
            _path,
            byte_offset,
            byte_size,
            need_waveform=need_waveform,
            use_sample_rate=use_sample_rate,
            waveform_transforms=waveform_transforms,
        )

    if n_slice_fields != 0:
        raise ValueError(f"Invalid path: {path}")

    # Plain file on disk.
    if not need_waveform:
        return get_features_from_npy_or_audio(
            _path, waveform_transforms=waveform_transforms
        )
    return get_waveform(
        _path,
        always_2d=False,
        output_sample_rate=use_sample_rate,
        waveform_transforms=waveform_transforms,
    )[0]
def _get_kaldi_fbank(
waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.fbank import Fbank, FbankOptions
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(
waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform)
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(
    path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster CPP implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized.

    Raises:
        ImportError: if neither backend is available.
    """
    waveform, sample_rate = get_waveform(
        path_or_fp, normalization=False, waveform_transforms=waveform_transforms
    )

    # Try the backends in order of preference; a backend that is not
    # installed reports None.
    for extract in (_get_kaldi_fbank, _get_torchaudio_fbank):
        features = extract(waveform, sample_rate, n_bins)
        if features is not None:
            return features

    raise ImportError(
        "Please install pyKaldi or torchaudio to enable "
        "online filterbank feature extraction"
    )
def is_npy_data(data: bytes) -> bool:
    """Return True when ``data`` starts with the bytes ``0x93`` and ``N``.

    Those are the first two bytes of the ``.npy`` magic string. Input
    shorter than two bytes is reported as not-npy instead of raising
    ``IndexError``.
    """
    return data[:2] == b"\x93N"
def is_sf_audio_data(data: bytes) -> bool:
    """Return True when ``data`` starts like a WAV, FLAC or OGG file.

    Only the first three bytes are compared: ``RIF`` (WAV), ``fLa`` (FLAC)
    or ``Ogg`` (OGG). Input shorter than three bytes is reported as
    not-audio instead of raising ``IndexError``.
    """
    return data[:3] in (b"RIF", b"fLa", b"Ogg")
def mmap_read(path: str, offset: int, length: int) -> bytes:
    """Return ``length`` bytes of the file at ``path`` starting at ``offset``.

    The file is memory-mapped read-only for the duration of the call.
    """
    with open(path, "rb") as fh, mmap.mmap(
        fh.fileno(), length=0, access=mmap.ACCESS_READ
    ) as mapped:
        end = offset + length
        return mapped[offset:end]
def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
    """Return ``length`` raw bytes of ``zip_path`` starting at byte ``offset``.

    The bytes are read straight from the file via ``mmap_read``; nothing is
    decompressed here.
    """
    return mmap_read(zip_path, offset, length)
def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse data path which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

    Args:
        path (str): the data path to parse
    Returns:
        file_path (str): the file path
        slice_ptr (list of int): empty in case 1;
          byte offset and length for the slice in case 2
    Raises:
        FileNotFoundError: if the file part of ``path`` is not an existing file.
    """
    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        # Everything before the first ":" is the file; the rest is slicing info.
        fields = path.split(":")
        _path, slice_ptr = fields[0], fields[1:]

    if not Path(_path).is_file():
        raise FileNotFoundError(f"File not found: {_path}")

    assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
    return _path, [int(i) for i in slice_ptr]
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
    """Build a ``win_length`` window and zero-pad it to ``n_fft`` samples.

    The padding is split between both sides; when it is odd, the extra
    sample goes to the right. ``n_fft`` must not be smaller than
    ``win_length``.
    """
    padding = n_fft - win_length
    assert padding >= 0
    left = padding // 2
    right = padding - left
    window = window_fn(win_length)
    return F.pad(window, (left, right))
def get_fourier_basis(n_fft: int) -> torch.Tensor:
    """Return the real and imaginary parts of the DFT basis, stacked row-wise.

    The first ``n_fft // 2 + 1`` rows of the DFT of the identity matrix are
    kept; their real parts come first, followed by their imaginary parts,
    giving a float tensor with ``2 * (n_fft // 2 + 1)`` rows and ``n_fft``
    columns.
    """
    n_bins = n_fft // 2 + 1
    dft = np.fft.fft(np.eye(n_fft))
    kept = dft[:n_bins, :]
    stacked = np.vstack([np.real(kept), np.imag(kept)])
    return torch.from_numpy(stacked).float()
def get_mel_filters(
    sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
    """Return the librosa mel filter bank as a float tensor.

    Args:
        sample_rate: sampling rate of the signal.
        n_fft: FFT size.
        n_mels: number of mel bands.
        f_min: lowest filter frequency.
        f_max: highest filter frequency.

    Raises:
        ImportError: if librosa is not installed.
    """
    try:
        import librosa
    except ImportError as err:
        raise ImportError("Please install librosa: pip install librosa") from err
    # Keyword arguments: recent librosa releases reject positional ones here,
    # while older releases accept these same names.
    basis = librosa.filters.mel(
        sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=f_min, fmax=f_max
    )
    return torch.from_numpy(basis).float()
class TTSSpectrogram(torch.nn.Module):
    """Magnitude (and optionally phase) spectrogram computed with a 1-D convolution.

    The convolution kernel is the windowed Fourier basis, stored as the
    buffer ``basis``.
    """

    def __init__(
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        window_fn: callable = torch.hann_window,
        return_phase: bool = False,
    ) -> None:
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.return_phase = return_phase

        # Apply the (zero-padded) analysis window to every basis row.
        window = get_window(window_fn, n_fft, win_length)
        kernel = get_fourier_basis(n_fft).unsqueeze(1)
        kernel *= window
        self.register_buffer("basis", kernel)

    def forward(
        self, waveform: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return the magnitude, or ``(magnitude, phase)`` if ``return_phase`` is set."""
        half = self.n_fft // 2
        n_bins = half + 1

        padded = F.pad(waveform.unsqueeze(1), (half, half), mode="reflect")
        projected = F.conv1d(padded, self.basis, stride=self.hop_length)

        # The first n_bins channels are real parts, the rest imaginary parts.
        real_part = projected[:, :n_bins, :]
        imag_part = projected[:, n_bins:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)

        if not self.return_phase:
            return magnitude
        phase = torch.atan2(imag_part, real_part)
        return magnitude, phase
class TTSMelScale(torch.nn.Module):
    """Project a spectrogram onto a mel filter bank stored as the buffer ``basis``."""

    def __init__(
        self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
    ) -> None:
        super().__init__()
        # FFT size passed to the filter bank, derived from n_stft.
        n_fft = (n_stft - 1) * 2
        mel_basis = get_mel_filters(sample_rate, n_fft, n_mels, f_min, f_max)
        self.register_buffer("basis", mel_basis)

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        """Return ``basis`` matrix-multiplied with ``specgram``."""
        return torch.matmul(self.basis, specgram)

View File

@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional
from fairseq.data import Dictionary
logger = logging.getLogger(__name__)
def get_config_from_yaml(yaml_path: Path):
    """Load the data config stored in the YAML file at ``yaml_path``.

    Args:
        yaml_path: path of the YAML file.

    Returns:
        The parsed config; an empty dict when the file has no content.

    Raises:
        FileNotFoundError: if ``yaml_path`` is not an existing file.
        ImportError: if PyYAML is not installed.
        Exception: if the file cannot be read or parsed; the original error
            is attached as the cause.
    """
    if not yaml_path.is_file():
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    try:
        import yaml
    except ImportError as err:
        # Fail here with a clear message rather than with a NameError later.
        raise ImportError("Please install PyYAML: pip install PyYAML") from err

    try:
        with open(yaml_path) as f:
            # NOTE(review): FullLoader is kept for compatibility with existing
            # configs; only load config files from trusted sources.
            config = yaml.load(f, Loader=yaml.FullLoader)
    except Exception as e:
        raise Exception(
            f"Failed to load config from {yaml_path.as_posix()}: {e}"
        ) from e

    # An empty YAML document parses to None; callers expect a mapping.
    return config or {}
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path: Path):
self.config = get_config_from_yaml(yaml_path)
self.root = yaml_path.parent
def _auto_convert_to_abs_path(self, x):
if isinstance(x, str):
if not Path(x).exists() and (self.root / x).exists():
return (self.root / x).as_posix()
elif isinstance(x, dict):
return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
return x
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", "dict.txt")
@property
def speaker_set_filename(self):
"""speaker set file under data root"""
return self.config.get("speaker_set_filename", None)
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_tgt_lang_tag", False)
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get("input_feat_per_channel", 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
return self.config.get("input_channels", 1)
@property
def sample_rate(self):
return self.config.get("sample_rate", 16_000)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return self.config.get("sampling_alpha", 1.0)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return self.config.get("use_audio_input", False)
def standardize_audio(self) -> bool:
return self.use_audio_input and self.config.get("standardize_audio", False)
@property
def use_sample_rate(self):
"""Needed by the dataset loader to see if the model requires
raw audio with specific sample rate as inputs."""
return self.config.get("use_sample_rate", 16000)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return self.config.get("audio_root", "")
def get_transforms(self, transform_type, split, is_train):
"""Split-specific feature transforms. Allowing train set
wildcard `_train`, evaluation set wildcard `_eval` and general
wildcard `*` for matching."""
from copy import deepcopy
cfg = deepcopy(self.config)
_cur = cfg.get(f"{transform_type}transforms", {})
cur = _cur.get(split)
cur = _cur.get("_train") if cur is None and is_train else cur
cur = _cur.get("_eval") if cur is None and not is_train else cur
cur = _cur.get("*") if cur is None else cur
return cur
def get_feature_transforms(self, split, is_train):
cfg = deepcopy(self.config)
# TODO: deprecate transforms
cur = self.get_transforms("", split, is_train)
if cur is not None:
logger.warning(
"Auto converting transforms into feature_transforms, "
"but transforms will be deprecated in the future. Please "
"update this in the config."
)
ft_transforms = self.get_transforms("feature_", split, is_train)
if ft_transforms:
cur.extend(ft_transforms)
else:
cur = self.get_transforms("feature_", split, is_train)
cfg["feature_transforms"] = cur
return cfg
def get_waveform_transforms(self, split, is_train):
    """Return a deep copy of the config whose ``waveform_transforms`` entry
    is replaced by the transforms resolved for ``split``."""
    resolved_cfg = deepcopy(self.config)
    resolved = self.get_transforms("waveform_", split, is_train)
    resolved_cfg["waveform_transforms"] = resolved
    return resolved_cfg
def get_dataset_transforms(self, split, is_train):
    """Return a deep copy of the config whose ``dataset_transforms`` entry
    is replaced by the transforms resolved for ``split``."""
    resolved_cfg = deepcopy(self.config)
    resolved = self.get_transforms("dataset_", split, is_train)
    resolved_cfg["dataset_transforms"] = resolved
    return resolved_cfg
@property
def global_cmvn_stats_npz(self) -> Optional[str]:
    """Return ``global_cmvn.stats_npz_path`` from the config (None when
    absent), passed through ``self._auto_convert_to_abs_path``."""
    cmvn_cfg = self.config.get("global_cmvn", {})
    stats_path = cmvn_cfg.get("stats_npz_path", None)
    return self._auto_convert_to_abs_path(stats_path)
@property
def vocoder(self) -> Dict[str, str]:
    """Return the ``vocoder`` config entry (default
    ``{"type": "griffin_lim"}``), passed through
    ``self._auto_convert_to_abs_path``."""
    spec = self.config.get("vocoder", {"type": "griffin_lim"})
    return self._auto_convert_to_abs_path(spec)
@property
def hub(self) -> Dict[str, str]:
    """Return the ``hub`` config entry, or an empty dict when absent."""
    hub_cfg = self.config.get("hub", {})
    return hub_cfg
class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        """Always None for this config."""
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        """Always None for this config."""
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms.

        Starts from ``self.input_channels`` and triples it when
        ``delta_deltas`` is listed under the ``_train`` key of the effective
        transform table.
        """
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        legacy = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if legacy and ft_transforms:
            # Merge into a fresh dict. Updating ``legacy`` in place would
            # rewrite self.config["transforms"] every time this property is
            # read, changing what later transform lookups see.
            merged = {**legacy, **ft_transforms}
        else:
            # NOTE(review): legacy ``transforms`` alone are ignored here, as
            # in the original code; confirm whether that is intended.
            merged = ft_transforms
        train_transforms = merged.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in train_transforms:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)
class MultitaskConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        """Load ``yaml_path`` and wrap every top-level entry in a
        ``SingleTaskConfig`` keyed by its name."""
        raw = get_config_from_yaml(yaml_path)
        self.config = {
            task_name: SingleTaskConfig(task_name, task_cfg)
            for task_name, task_cfg in raw.items()
        }

    def get_all_tasks(self):
        """Return the mapping from task name to its ``SingleTaskConfig``."""
        return self.config

    def get_single_task(self, name):
        """Return the config of task ``name``; asserts that it exists."""
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
        the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name starts with 'target' and whose
        decoder_type is 'transformer' is used. Returns -1 if nothing matches.
        """
        tasks = list(self.config.items())
        flagged = [
            pos for pos, (_, task) in enumerate(tasks) if task.is_first_pass_decoder
        ]
        if flagged:
            return flagged[-1]
        fallback = [
            pos
            for pos, (task_name, task) in enumerate(tasks)
            if task_name.startswith("target") and task.decoder_type == "transformer"
        ]
        return fallback[-1] if fallback else -1
class SingleTaskConfig(object):
    """Configuration of a single task, wrapping its raw config mapping."""

    def __init__(self, name, config):
        """Store the task name and raw config, and load the target dictionary.

        ``tgt_dict`` is loaded from the ``dict`` config entry when that path
        is set and exists; otherwise it is None.
        """
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        # Path("") resolves to the current directory, which always exists, so
        # an unset "dict" entry must be rejected explicitly before the
        # existence check; otherwise Dictionary.load("") would be attempted.
        has_dict = bool(dict_path) and Path(dict_path).exists()
        self.tgt_dict = Dictionary.load(dict_path) if has_dict else None

    @property
    def data(self):
        """The ``data`` config entry (default empty string)."""
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        """The ``decoder_type`` config entry (default "transformer")."""
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        # NOTE: the attribute is assigned on the imported config *class*
        # itself (no instance is created), exactly as in the original code.
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        """Zero-based layer index derived from the configured layer number."""
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        """"decay" when both ``loss_weight_max`` and
        ``loss_weight_decay_steps`` are configured, otherwise "fixed"."""
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        """Return the loss weight to use after ``num_updates`` updates.

        For the fixed schedule this is the ``loss_weight`` entry (default
        1.0). For the decay schedule the weight falls linearly from
        ``loss_weight_max`` and is floored at ``loss_weight_min`` (default
        0.0001), reaching the floor after ``loss_weight_decay_steps`` updates.
        """
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        """The ``rdrop_alpha`` config entry (default 0.0)."""
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        """The ``is_first_pass_decoder`` flag (default False).

        When the flag is set, raises ValueError for a CTC decoder and raises
        ``Warning`` (as an exception) when the task name lacks "target".
        """
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        """The ``lang_tag_mapping`` config entry (default empty dict)."""
        return self.config.get("lang_tag_mapping", {})

View File

@@ -0,0 +1,53 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioDatasetTransform(AudioTransform):
    """Base class of the dataset-level audio transforms; it is the type
    passed to ``register_audio_transform`` by
    ``register_audio_dataset_transform``."""

    pass
# Lookup table read by get_audio_dataset_transform(name); handed to
# register_audio_transform, which presumably fills it (confirm in
# fairseq.data.audio).
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
# Companion set handed to register_audio_transform alongside the registry;
# presumably the names of already-registered classes (TODO confirm).
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
def get_audio_dataset_transform(name):
    """Return the entry stored under ``name`` in
    ``AUDIO_DATASET_TRANSFORM_REGISTRY``; raises KeyError if absent."""
    transform = AUDIO_DATASET_TRANSFORM_REGISTRY[name]
    return transform
def register_audio_dataset_transform(name):
    """Return the result of ``register_audio_transform`` for ``name``, bound
    to the dataset-transform base class, registry and class-name set.

    Used in this package as a class decorator, e.g.
    ``@register_audio_dataset_transform("concataugment")``.
    """
    decorator = register_audio_transform(
        name,
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )
    return decorator
# Module-level side effect: hand this package's directory and the "dataset"
# tag to import_transforms (presumably imports the transform modules in this
# directory so their decorators run; confirm in fairseq.data.audio).
import_transforms(os.path.dirname(__file__), "dataset")
class CompositeAudioDatasetTransform(CompositeAudioTransform):
    """Composite of dataset-level audio transforms."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the composite from ``config`` by delegating to the parent's
        ``_from_config_dict`` with the "dataset" transform type and
        ``return_empty=True``."""
        build = super()._from_config_dict
        return build(
            cls,
            "dataset",
            get_audio_dataset_transform,
            CompositeAudioDatasetTransform,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        """Return the first transform in ``self.transforms`` that is an
        instance of ``cls``, or None when there is none."""
        matches = (t for t in self.transforms if isinstance(t, cls))
        return next(matches, None)

    def has_transform(self, cls):
        """Tell whether ``self.transforms`` holds an instance of ``cls``."""
        found = self.get_transform(cls)
        return found is not None

View File

@@ -0,0 +1,61 @@
from typing import List
import numpy as np
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
# Default ConcatAugment settings, shared by from_config_dict and __init__:
#   rate       - compared against np.random.random() to decide whether to apply
#   max_tokens - limit checked against n_frames of one sample and of the pair
#   attempts   - number of random draws for the second sample
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
    """Dataset transform that picks a second sample index to concatenate."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build a ``ConcatAugment`` from ``config``, filling every missing
        key from ``_DEFAULTS``."""
        options = config if config is not None else {}
        kwargs = {
            key: options.get(key, _DEFAULTS[key])
            for key in ("rate", "max_tokens", "attempts")
        }
        return ConcatAugment(**kwargs)

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        max_tokens=_DEFAULTS["max_tokens"],
        attempts=_DEFAULTS["attempts"],
    ):
        self.rate = rate
        self.max_tokens = max_tokens
        self.attempts = attempts

    def __repr__(self):
        fields = ", ".join(
            (
                f"rate={self.rate}",
                f"max_tokens={self.max_tokens}",
                f"attempts={self.attempts}",
            )
        )
        return f"{self.__class__.__name__}({fields})"

    def find_indices(self, index: int, n_frames: List[int], n_samples: int):
        """Return ``[index]`` or ``[index, index2]``.

        A partner is searched only when a random draw does not exceed
        ``rate`` and the sample itself is within ``max_tokens``. Up to
        ``attempts`` random candidates are tried; a candidate is accepted
        when it differs from ``index`` and (if ``max_tokens`` is set) the
        combined frame count stays below ``max_tokens``.
        """
        # skip conditions: application rate, max_tokens limit exceeded
        if np.random.random() > self.rate:
            return [index]
        if self.max_tokens and n_frames[index] > self.max_tokens:
            return [index]

        # pick second sample to concatenate
        for _ in range(self.attempts):
            candidate = np.random.randint(0, n_samples)
            if candidate == index:
                continue
            if (
                self.max_tokens
                and n_frames[index] + n_frames[candidate] >= self.max_tokens
            ):
                continue
            return [index, candidate]

        return [index]

View File

@@ -0,0 +1,105 @@
import numpy as np
import torch
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
from fairseq.data.audio.waveform_transforms.noiseaugment import (
NoiseAugmentTransform,
)
# Default NoisyOverlapAugment settings, shared by from_config_dict and
# __init__. "rate" and "mixing_noise_rate" are thresholds compared against
# np.random.random(); the *_snr_min/max pairs bound the rand_uniform draw of
# the SNR used when mixing noise or another utterance, respectively.
_DEFAULTS = {
    "rate": 0.25,
    "mixing_noise_rate": 0.1,
    "noise_path": "",
    "noise_snr_min": -5,
    "noise_snr_max": 5,
    "utterance_snr_min": -5,
    "utterance_snr_max": 5,
}
@register_audio_dataset_transform("noisyoverlapaugment")
class NoisyOverlapAugment(AudioDatasetTransform):
    """Dataset transform that overlays part of another utterance from the
    batch, or a noise sample, onto a random segment of each source."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the transform from ``config``, filling missing keys from
        ``_DEFAULTS``."""
        _config = {} if config is None else config
        return NoisyOverlapAugment(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
            _config.get("noise_path", _DEFAULTS["noise_path"]),
            _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
            _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
            _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
            _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
        )

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
        noise_path=_DEFAULTS["noise_path"],
        noise_snr_min=_DEFAULTS["noise_snr_min"],
        noise_snr_max=_DEFAULTS["noise_snr_max"],
        utterance_snr_min=_DEFAULTS["utterance_snr_min"],
        utterance_snr_max=_DEFAULTS["utterance_snr_max"],
    ):
        self.rate = rate
        self.mixing_noise_rate = mixing_noise_rate
        # Only used here for pick_sample(); built from the noise path.
        self.noise_shaper = NoiseAugmentTransform(noise_path)
        self.noise_snr_min = noise_snr_min
        self.noise_snr_max = noise_snr_max
        self.utterance_snr_min = utterance_snr_min
        self.utterance_snr_max = utterance_snr_max

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"rate={self.rate}",
                    f"mixing_noise_rate={self.mixing_noise_rate}",
                    f"noise_snr_min={self.noise_snr_min}",
                    f"noise_snr_max={self.noise_snr_max}",
                    f"utterance_snr_min={self.utterance_snr_min}",
                    f"utterance_snr_max={self.utterance_snr_max}",
                ]
            )
            + ")"
        )

    def __call__(self, sources):
        """Augment ``sources`` in place and return it.

        Each source is processed when a random draw does not exceed
        ``rate``. The secondary signal is another entry of ``sources``
        (draw above ``mixing_noise_rate``) or a sample from the noise
        shaper; it is scaled to a randomly drawn SNR and added to a random
        segment of the source. Sources too short to host a mix segment, and
        silent secondary signals, are left untouched.
        """
        for i, source in enumerate(sources):
            if np.random.random() > self.rate:
                continue

            pri = source.numpy()

            if np.random.random() > self.mixing_noise_rate:
                sec = sources[np.random.randint(0, len(sources))].numpy()
                snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
            else:
                sec = self.noise_shaper.pick_sample(source.shape)
                snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)

            L1 = pri.shape[-1]
            L2 = sec.shape[-1]
            max_mix_len = min(round(L1 / 2), L2)
            if max_mix_len < 1:
                # np.random.randint(0, 0) raises ValueError; a primary of at
                # most one sample or an empty secondary cannot be mixed.
                continue
            mix_len = np.random.randint(0, max_mix_len)
            s_source = np.random.randint(0, L1 - mix_len)
            s_sec = np.random.randint(0, L2 - mix_len)

            sec_power = np.mean(sec**2)
            if sec_power == 0:
                continue

            pri_power = np.mean(pri**2)
            # Gain that puts the secondary at the drawn SNR (in dB) below
            # the primary's average power.
            scl = np.sqrt(pri_power / (np.power(10, snr / 10) * sec_power))

            pri[s_source : s_source + mix_len] = np.add(
                pri[s_source : s_source + mix_len],
                np.multiply(scl, sec[s_sec : s_sec + mix_len]),
            )
            sources[i] = torch.from_numpy(pri).float()

        return sources

View File

@@ -0,0 +1,43 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioFeatureTransform(AudioTransform):
    """Base class of the feature-level audio transforms; it is the type
    passed to ``register_audio_transform`` by
    ``register_audio_feature_transform``."""

    pass
# Lookup table read by get_audio_feature_transform(name); handed to
# register_audio_transform, which presumably fills it (confirm in
# fairseq.data.audio).
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
# Companion set handed to register_audio_transform alongside the registry;
# presumably the names of already-registered classes (TODO confirm).
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def get_audio_feature_transform(name):
    """Return the entry stored under ``name`` in
    ``AUDIO_FEATURE_TRANSFORM_REGISTRY``; raises KeyError if absent."""
    transform = AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
    return transform
def register_audio_feature_transform(name):
    """Return the result of ``register_audio_transform`` for ``name``, bound
    to the feature-transform base class, registry and class-name set.

    Used in this package as a class decorator, e.g.
    ``@register_audio_feature_transform("specaugment")``.
    """
    decorator = register_audio_transform(
        name,
        AudioFeatureTransform,
        AUDIO_FEATURE_TRANSFORM_REGISTRY,
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
    )
    return decorator
# Module-level side effect: hand this package's directory and the "feature"
# tag to import_transforms (presumably imports the transform modules in this
# directory so their decorators run; confirm in fairseq.data.audio).
import_transforms(os.path.dirname(__file__), "feature")
class CompositeAudioFeatureTransform(CompositeAudioTransform):
    """Composite of feature-level audio transforms."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the composite from ``config`` by delegating to the parent's
        ``_from_config_dict`` with the "feature" transform type."""
        build = super()._from_config_dict
        return build(
            cls,
            "feature",
            get_audio_feature_transform,
            CompositeAudioFeatureTransform,
            config,
        )

View File

@@ -0,0 +1,37 @@
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
    """Expand delta-deltas features from spectrum."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the transform from ``config``; ``win_length`` defaults to 5."""
        _config = {} if config is None else config
        return DeltaDeltas(_config.get("win_length", 5))

    def __init__(self, win_length=5):
        # Window length forwarded to torchaudio's compute_deltas.
        self.win_length = win_length

    def __repr__(self):
        return self.__class__.__name__

    def __call__(self, spectrogram):
        """Return ``spectrogram`` with its deltas and delta-deltas appended
        along the feature axis.

        ``spectrogram`` must be a 2-D NumPy array laid out as T x F; the
        result is a NumPy array of shape T x 3F.
        """
        from torchaudio.functional import compute_deltas

        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
        # spectrogram is T x F, while compute_deltas takes (…, F, T)
        spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
        # Pass the configured window explicitly; previously the setting was
        # stored but ignored, so compute_deltas always used its own default.
        delta = compute_deltas(spectrogram, win_length=self.win_length)
        delta_delta = compute_deltas(delta, win_length=self.win_length)

        out_feat = np.concatenate(
            [spectrogram.numpy(), delta.numpy(), delta_delta.numpy()], axis=0
        )
        out_feat = np.transpose(out_feat)
        return out_feat

View File

@@ -0,0 +1,29 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the transform from the ``stats_npz_path`` entry of ``config``."""
        options = config if config is not None else {}
        return GlobalCMVN(options.get("stats_npz_path"))

    def __init__(self, stats_npz_path):
        """Load the ``mean`` and ``std`` entries from ``stats_npz_path``."""
        self.stats_npz_path = stats_npz_path
        stats = np.load(stats_npz_path)
        self.mean = stats["mean"]
        self.std = stats["std"]

    def __repr__(self):
        name = self.__class__.__name__
        return name + f'(stats_npz_path="{self.stats_npz_path}")'

    def __call__(self, x):
        """Return ``(x - mean) / std`` using the loaded statistics."""
        centered = np.subtract(x, self.mean)
        return np.divide(centered, self.std)

View File

@@ -0,0 +1,131 @@
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the transform from ``config``. Missing numeric settings
        default to 0 (0.0 for ``time_mask_p``) and ``mask_value`` to None."""
        options = config if config is not None else {}
        return SpecAugmentTransform(
            options.get("time_warp_W", 0),
            options.get("freq_mask_N", 0),
            options.get("freq_mask_F", 0),
            options.get("time_mask_N", 0),
            options.get("time_mask_T", 0),
            options.get("time_mask_p", 0.0),
            options.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        # Sanity checks
        mask_value_ok = mask_value is None or isinstance(mask_value, numbers.Number)
        assert (
            mask_value_ok
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.mask_value = mask_value
        self.time_warp_w = time_warp_w
        self.freq_mask_n, self.freq_mask_f = freq_mask_n, freq_mask_f
        self.time_mask_n, self.time_mask_t = time_mask_n, time_mask_t
        self.time_mask_p = time_mask_p

    def __repr__(self):
        shown = (
            "time_warp_w",
            "freq_mask_n",
            "freq_mask_f",
            "time_mask_n",
            "time_mask_t",
            "time_mask_p",
        )
        fields = ", ".join(f"{attr}={getattr(self, attr)}" for attr in shown)
        return f"{self.__class__.__name__}({fields})"

    def __call__(self, spectrogram):
        """Return a distorted copy of the 2-D ``spectrogram`` (frames x freqs).

        Applies, in order, an optional time warp, ``freq_mask_n`` frequency
        masks and ``time_mask_n`` time masks. The input itself is returned
        when it has no frames or fewer frequency bins than ``freq_mask_f``.
        """
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        # 'tau' and 'miu' in the paper, respectively.
        num_frames, num_freqs = spectrogram.shape[0], spectrogram.shape[1]
        fill = self.mask_value

        if fill is None:  # if no value was specified, use local mean.
            fill = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        warp = self.time_warp_w
        if warp > 0 and 2 * warp < num_frames:
            import cv2

            w0 = np.random.randint(warp, num_frames - warp)
            w = np.random.randint(-warp + 1, warp)
            head, tail = distorted[:w0, :], distorted[w0:, :]
            head = cv2.resize(
                head, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
            )
            tail = cv2.resize(
                tail,
                dsize=(num_freqs, num_frames - w0 - w),
                interpolation=cv2.INTER_LINEAR,
            )
            distorted = np.concatenate((head, tail), axis=0)

        for _ in range(self.freq_mask_n):
            width = np.random.randint(0, self.freq_mask_f)
            first_bin = np.random.randint(0, num_freqs - width)
            if width != 0:
                distorted[:, first_bin : first_bin + width] = fill

        # Cap the mask length by both time_mask_t and the allowed fraction
        # time_mask_p of the utterance.
        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _ in range(self.time_mask_n):
            span = np.random.randint(0, max_time_mask_t)
            first_frame = np.random.randint(0, num_frames - span)
            if span != 0:
                distorted[first_frame : first_frame + span, :] = fill

        return distorted

View File

@@ -0,0 +1,41 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the transform from ``config``; both ``norm_means`` and
        ``norm_vars`` default to True."""
        options = config if config is not None else {}
        return UtteranceCMVN(
            options.get("norm_means", True),
            options.get("norm_vars", True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means = norm_means
        self.norm_vars = norm_vars

    def __repr__(self):
        name = self.__class__.__name__
        return name + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"

    def __call__(self, x):
        """Normalize ``x`` along axis 0.

        The mean and the sum of squares are taken from the untouched input.
        The mean is subtracted when ``norm_means`` is set; when ``norm_vars``
        is set the result is divided by the standard deviation, whose
        variance is floored at 1e-10.
        """
        mean = x.mean(axis=0)
        square_sums = (x**2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if not self.norm_vars:
            return x

        var = square_sums / x.shape[0] - mean**2
        std = np.sqrt(np.maximum(var, 1e-10))
        return np.divide(x, std)

View File

@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import csv
import logging
import os.path as op
from typing import List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset,
TextToSpeechDatasetCreator,
)
logger = logging.getLogger(__name__)
class FrmTextToSpeechDataset(TextToSpeechDataset):
    """Text-to-speech dataset whose target texts are encoded at a fixed
    frame rate, with optional random chunking of training samples."""

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        do_chunk=False,
        chunk_bound=-1,
        chunk_init=50,
        chunk_incr=5,
        add_eos=True,
        dedup=True,
        ref_fpu=-1,
    ):
        # It assumes texts are encoded at a fixed frame-rate
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        # Post-processing switches used by __getitem__.
        self.add_eos = add_eos
        self.dedup = dedup
        self.ref_fpu = ref_fpu

        # Chunking schedule; chunk_size stays -1 until set_epoch() is called.
        self.do_chunk = do_chunk
        self.chunk_bound = chunk_bound
        self.chunk_init = chunk_init
        self.chunk_incr = chunk_incr
        self.chunk_size = -1

        if do_chunk:
            assert self.chunk_incr >= 0
            assert self.pre_tokenizer is None

    def __getitem__(self, index):
        """Return ``(index, source, target, speaker_id)`` for one sample.

        A trailing EOS is stripped from the parent's target, the
        frame-per-unit ratio is checked against ``ref_fpu`` (10% tolerance,
        skipped when ``ref_fpu`` is -1), training samples are optionally cut
        to a random chunk, and the target is then optionally de-duplicated
        and re-terminated with EOS.
        """
        index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
        if target[-1].item() == self.tgt_dict.eos_index:
            target = target[:-1]

        fpu = source.size(0) / target.size(0)  # frame-per-unit
        fps = self.n_frames_per_step
        assert (
            self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
        ), f"{fpu*fps} != {self.ref_fpu}"

        # only chunk training split
        if self.is_train_split and self.do_chunk and self.chunk_size > 0:
            # The language tag (if prepended) is kept outside the chunk.
            n_tag = int(self.data_cfg.prepend_tgt_lang_tag)
            lang, text = target[:n_tag], target[n_tag:]
            n_units = len(text)
            chunk_size = min(self.chunk_size, n_units)
            chunk_start = np.random.randint(n_units - chunk_size + 1)
            chunk_end = chunk_start + chunk_size
            target = torch.cat((lang, text[chunk_start:chunk_end]), 0)

            # Matching frame window in the source.
            f_size = int(np.floor(chunk_size * fpu))
            f_start = int(np.floor(chunk_start * fpu))
            assert f_size > 0
            source = source[f_start : f_start + f_size, :]

        if self.dedup:
            target = torch.unique_consecutive(target)

        if self.add_eos:
            eos = torch.LongTensor([self.tgt_dict.eos_index])
            target = torch.cat((target, eos), 0)

        return index, source, target, speaker_id

    def set_epoch(self, epoch):
        """Update the chunk size for ``epoch`` (training split with chunking
        only): ``chunk_init + epoch * chunk_incr``, capped at ``chunk_bound``
        when that bound is positive."""
        if not (self.is_train_split and self.do_chunk):
            return
        previous = self.chunk_size
        self.chunk_size = self.chunk_init + epoch * self.chunk_incr
        if self.chunk_bound > 0:
            self.chunk_size = min(self.chunk_size, self.chunk_bound)
        logger.info(
            (
                f"{self.split}: setting chunk size "
                f"from {previous} to {self.chunk_size}"
            )
        )
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
    """Builds ``FrmTextToSpeechDataset`` objects from TSV manifests."""

    # inherit for key names
    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        n_frames_per_step: int,
        speaker_to_id,
        do_chunk: bool = False,
        chunk_bound: int = -1,
        chunk_init: int = 50,
        chunk_incr: int = 5,
        add_eos: bool = True,
        dedup: bool = True,
        ref_fpu: float = -1,
    ) -> FrmTextToSpeechDataset:
        """Read ``<root>/<split>.tsv`` and build the dataset from its rows.

        Raises FileNotFoundError when the TSV is missing and asserts that it
        has at least one row. ID, audio, frame-count and target-text columns
        are required; the remaining columns fall back to class defaults.
        """
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            rows = list(map(dict, reader))
            assert len(rows) > 0

        # Required columns.
        ids = [row[cls.KEY_ID] for row in rows]
        audio_paths = [
            op.join(data_cfg.audio_root, row[cls.KEY_AUDIO]) for row in rows
        ]
        n_frames = [int(row[cls.KEY_N_FRAMES]) for row in rows]
        tgt_texts = [row[cls.KEY_TGT_TEXT] for row in rows]
        # Optional columns with class-level defaults.
        src_texts = [row.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for row in rows]
        speakers = [row.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for row in rows]
        src_langs = [row.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for row in rows]
        tgt_langs = [row.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for row in rows]

        return FrmTextToSpeechDataset(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            do_chunk=do_chunk,
            chunk_bound=chunk_bound,
            chunk_init=chunk_init,
            chunk_incr=chunk_incr,
            add_eos=add_eos,
            dedup=dedup,
            ref_fpu=ref_fpu,
        )

View File

@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
)
import io
logger = logging.getLogger(__name__)
def load_audio(manifest_path, max_keep, min_keep):
    """Read an audio manifest and filter its entries by size.

    The first line of the manifest is the audio root; every following line
    is ``<name>\\t<size>``. Entries smaller than ``min_keep`` or larger than
    ``max_keep`` are dropped (either bound may be None to disable it).

    Returns ``(root, names, inds, tot, sizes)`` where ``inds`` are the
    zero-based line indices of the kept entries and ``tot`` is the total
    number of entries read. A manifest with no entries, or one where every
    entry is filtered out, yields empty lists instead of raising.
    """
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    tot = 0  # stays 0 when the manifest has no entry lines
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
            tot = ind + 1
    # default=None keeps the summary from crashing when nothing was kept.
    longest = max(sizes, default=None)
    shortest = min(sizes, default=None)
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={longest}, shortest-loaded={shortest}"
        )
    )
    return root, names, inds, tot, sizes
def load_label(label_path, inds, tot):
    """Read all lines of ``label_path`` (trailing whitespace stripped) and
    return those at positions ``inds``.

    Asserts that the file has exactly ``tot`` lines.
    """
    with open(label_path) as f:
        all_labels = list(map(str.rstrip, f))
        assert (
            len(all_labels) == tot
        ), f"number of labels does not match ({len(all_labels)} != {tot})"
    kept = []
    for position in inds:
        kept.append(all_labels[position])
    return kept
def load_label_offset(label_path, inds, tot):
    """Return the ``(start, end)`` byte offsets of the lines at ``inds``.

    The file is read in binary mode so the offsets are the real on-disk
    byte positions, whatever the platform's default text encoding or the
    file's line-ending convention. Asserts that the file has exactly
    ``tot`` lines.
    """
    with open(label_path, "rb") as f:
        # len() of a bytes line is its byte length, terminator included.
        code_lengths = [len(line) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        # Running sum of line lengths gives each line's start offset.
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets
def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    """Warn about (audio, label) pairs whose durations disagree.

    Audio duration is ``audio_sizes[i] / audio_rate``; label duration is the
    number of whitespace-separated tokens on the label line divided by
    ``label_rate``. A warning is logged for every pair differing by more
    than ``tol`` seconds, plus a summary when any were found. Nothing is
    checked for a negative ``label_rate`` (sequence labels). Returns None.
    """
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        token_counts = [len(line.rstrip().split()) for line in f]
        assert len(token_counts) == tot
        lengths = [token_counts[i] for i in inds]

    num_invalid = 0
    for pos, line_idx in enumerate(inds):
        n_samples = audio_sizes[pos]
        n_labels = lengths[pos]
        dur_from_audio = n_samples / audio_rate
        dur_from_label = n_labels / label_rate
        if abs(dur_from_audio - dur_from_label) <= tol:
            continue
        logger.warning(
            (
                f"audio and label duration differ too much "
                f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                f"in line {line_idx+1} of {label_path}. Check if `label_rate` "
                f"is correctly set (currently {label_rate}). "
                f"num. of samples = {n_samples}; "
                f"label length = {n_labels}"
            )
        )
        num_invalid += 1

    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )
class HubertDataset(FairseqDataset):
def __init__(
    self,
    manifest_path: str,
    sample_rate: float,
    label_paths: List[str],
    label_rates: Union[List[float], float],  # -1 for sequence labels
    pad_list: List[str],
    eos_list: List[str],
    label_processors: Optional[List[Any]] = None,
    max_keep_sample_size: Optional[int] = None,
    min_keep_sample_size: Optional[int] = None,
    max_sample_size: Optional[int] = None,
    shuffle: bool = True,
    pad_audio: bool = False,
    normalize: bool = False,
    store_labels: bool = True,
    random_crop: bool = False,
    single_target: bool = False,
):
    """Load the audio manifest and label files and store the settings.

    The manifest is filtered by ``min_keep_sample_size`` /
    ``max_keep_sample_size``. Labels are either loaded into memory
    (``store_labels``) or indexed by file offset for lazy reads. A single
    float ``label_rates`` is repeated once per label path. Each label file
    is checked against the audio durations.
    """
    self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
        manifest_path, max_keep_sample_size, min_keep_sample_size
    )

    # Plain settings.
    self.sample_rate = sample_rate
    self.shuffle = shuffle
    self.random_crop = random_crop
    self.single_target = single_target

    # Label-related settings.
    self.num_labels = len(label_paths)
    self.pad_list = pad_list
    self.eos_list = eos_list
    self.label_processors = label_processors
    if isinstance(label_rates, float):
        self.label_rates = [label_rates] * len(label_paths)
    else:
        self.label_rates = label_rates

    self.store_labels = store_labels
    if store_labels:
        self.label_list = [load_label(p, inds, tot) for p in label_paths]
    else:
        self.label_paths = label_paths
        self.label_offsets_list = [
            load_label_offset(p, inds, tot) for p in label_paths
        ]
    assert label_processors is None or len(label_processors) == self.num_labels
    for label_path, label_rate in zip(label_paths, self.label_rates):
        verify_label_lengths(
            self.sizes, sample_rate, label_path, label_rate, inds, tot
        )

    # No explicit limit means "unbounded".
    if max_sample_size is None:
        max_sample_size = sys.maxsize
    self.max_sample_size = max_sample_size
    self.pad_audio = pad_audio
    self.normalize = normalize
    logger.info(
        f"pad_audio={pad_audio}, random_crop={random_crop}, "
        f"normalize={normalize}, max_sample_size={self.max_sample_size}"
    )
def get_audio(self, index):
    """Read the waveform of sample ``index`` and post-process it.

    The path is ``audio_root`` joined with the manifest name. When
    ``parse_path`` yields a slice pointer, the bytes are read from the
    ``.zip`` archive at that location; otherwise the file is read
    directly. The result is a float tensor passed through
    ``self.postprocess``.
    """
    import soundfile as sf

    wav_path = os.path.join(self.audio_root, self.audio_names[index])
    _path, slice_ptr = parse_path(wav_path)
    if len(slice_ptr) == 0:
        audio_src = _path
    else:
        assert _path.endswith(".zip")
        payload = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
        audio_src = io.BytesIO(payload)
    wav, cur_sample_rate = sf.read(audio_src)

    tensor = torch.from_numpy(wav).float()
    return self.postprocess(tensor, cur_sample_rate)
def get_label(self, index, label_idx):
    """Return label number ``label_idx`` of sample ``index``.

    With ``store_labels`` the label comes from the in-memory list;
    otherwise it is read from the label file using the stored
    ``(start, end)`` byte offsets. The matching label processor, if any,
    is applied to the result.
    """
    if self.store_labels:
        label = self.label_list[label_idx][index]
    else:
        # The offsets are byte positions, so seek and read in binary mode
        # and decode afterwards. A text-mode read(n) counts characters,
        # which overruns the line whenever it holds multi-byte characters.
        with open(self.label_paths[label_idx], "rb") as f:
            offset_s, offset_e = self.label_offsets_list[label_idx][index]
            f.seek(offset_s)
            label = f.read(offset_e - offset_s).decode("utf-8")

    if self.label_processors is not None:
        label = self.label_processors[label_idx](label)
    return label
def get_labels(self, index):
    """Return the labels of sample ``index``, one per label stream."""
    labels = []
    for label_idx in range(self.num_labels):
        labels.append(self.get_label(index, label_idx))
    return labels
def __getitem__(self, index):
    """Return the sample dict with keys ``id``, ``source`` (waveform) and
    ``label_list`` for ``index``."""
    waveform = self.get_audio(index)
    label_list = self.get_labels(index)
    sample = {"id": index, "source": waveform, "label_list": label_list}
    return sample
def __len__(self):
    """Number of samples, i.e. the number of entries in ``self.sizes``."""
    n_samples = len(self.sizes)
    return n_samples
def crop_to_max_size(self, wav, target_size):
    """Crop ``wav`` to at most ``target_size`` items.

    Returns ``(cropped, start)``. Input that already fits is returned
    unchanged with start 0. Otherwise the window starts at 0, or at a
    random offset when ``self.random_crop`` is set.
    """
    excess = len(wav) - target_size
    if excess <= 0:
        return wav, 0

    offset = np.random.randint(0, excess + 1) if self.random_crop else 0
    return wav[offset : offset + target_size], offset
def collater(self, samples):
    """Collate a list of samples into a batch dict.

    Samples whose ``source`` is None are dropped; an empty dict is
    returned when none remain. The common audio size is the longest
    (``pad_audio``) or the shortest sample length, capped at
    ``max_sample_size``. Targets are stored under singular keys when
    ``single_target`` is set and under ``*_list`` keys otherwise.
    """
    # target = max(sizes) -> random_crop not used
    # target = max_sample_size -> random_crop used for long
    valid = [s for s in samples if s["source"] is not None]
    if not valid:
        return {}

    audios = [s["source"] for s in valid]
    lengths = [len(a) for a in audios]
    reference = max(lengths) if self.pad_audio else min(lengths)
    audio_size = min(reference, self.max_sample_size)
    collated_audios, padding_mask, audio_starts = self.collater_audio(
        audios, audio_size
    )

    targets_by_label = []
    for label_idx in range(self.num_labels):
        targets_by_label.append([s["label_list"][label_idx] for s in valid])
    targets_list, lengths_list, ntokens_list = self.collater_label(
        targets_by_label, audio_size, audio_starts
    )

    batch = {
        "id": torch.LongTensor([s["id"] for s in valid]),
        "net_input": {"source": collated_audios, "padding_mask": padding_mask},
    }
    if self.single_target:
        batch["target_lengths"] = lengths_list[0]
        batch["ntokens"] = ntokens_list[0]
        batch["target"] = targets_list[0]
    else:
        batch["target_lengths_list"] = lengths_list
        batch["ntokens_list"] = ntokens_list
        batch["target_list"] = targets_list
    return batch
def collater_audio(self, audios, audio_size):
    """Stack ``audios`` into a ``len(audios) x audio_size`` tensor.

    Shorter items are zero-padded on the right (only allowed with
    ``pad_audio``) and their padded positions flagged in the mask; longer
    items are cropped with ``crop_to_max_size``. Returns
    ``(collated_audios, padding_mask, audio_starts)`` where
    ``audio_starts`` holds the crop offset of every item.
    """
    n_items = len(audios)
    collated_audios = audios[0].new_zeros(n_items, audio_size)
    padding_mask = (
        torch.BoolTensor(collated_audios.shape).fill_(False)
        # if self.pad_audio else None
    )
    audio_starts = [0] * n_items
    for row, audio in enumerate(audios):
        diff = len(audio) - audio_size
        if diff > 0:
            collated_audios[row], audio_starts[row] = self.crop_to_max_size(
                audio, audio_size
            )
        elif diff < 0:
            assert self.pad_audio
            filler = audio.new_full((-diff,), 0.0)
            collated_audios[row] = torch.cat([audio, filler])
            # diff is negative: this marks the trailing -diff positions.
            padding_mask[row, diff:] = True
        else:
            collated_audios[row] = audio
    return collated_audios, padding_mask, audio_starts
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
    """Collate frame-level labels so they line up with the cropped audio.

    Audio offsets and the audio length are converted from samples to label
    frames using ``label_rate / self.sample_rate``; every target is sliced
    to that window and the slices are right-padded with ``pad`` into one
    batch by ``data_utils.collate_tokens``.

    Returns:
        ``(targets, lengths, ntokens)``: the padded batch, the per-sample
        label counts as a ``LongTensor``, and their total.
    """
    assert label_rate > 0
    frames_per_sample = label_rate / self.sample_rate
    frm_starts = [int(round(start * frames_per_sample)) for start in audio_starts]
    frm_size = int(round(audio_size * frames_per_sample))
    if not self.pad_audio:
        # Without padding every slice must have the same length, so shrink
        # the window to what the shortest remaining target can supply.
        remaining = [len(tgt) - begin for tgt, begin in zip(targets, frm_starts)]
        frm_size = min(frm_size, *remaining)
    targets = [
        tgt[begin : begin + frm_size] for tgt, begin in zip(targets, frm_starts)
    ]
    logger.debug(f"audio_starts={audio_starts}")
    logger.debug(f"frame_starts={frm_starts}")
    logger.debug(f"frame_size={frm_size}")

    lengths = torch.LongTensor([len(tgt) for tgt in targets])
    ntokens = lengths.sum().item()
    targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
    return targets, lengths, ntokens
def collater_seq_label(self, targets, pad):
    """Collate sequence-level labels without aligning them to the audio.

    Returns:
        ``(targets, lengths, ntokens)``: the targets right-padded with
        ``pad`` by ``data_utils.collate_tokens``, the per-sample lengths as
        a ``LongTensor``, and the sum of those lengths.
    """
    per_sample = [len(tgt) for tgt in targets]
    lengths = torch.LongTensor(per_sample)
    ntokens = lengths.sum().item()
    padded = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
    return padded, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
    """Collate every label set, choosing the collater by its label rate.

    A label rate of ``-1.0`` marks a sequence-level label set (handled by
    ``collater_seq_label``); any other rate is treated as frame-level and
    handled by ``collater_frm_label``.

    Returns:
        Three parallel lists with one entry per label set: collated
        targets, lengths, and token counts.
    """
    targets_list, lengths_list, ntokens_list = [], [], []
    per_label = zip(targets_by_label, self.label_rates, self.pad_list)
    for label_targets, rate, pad_idx in per_label:
        if rate == -1.0:
            collated = self.collater_seq_label(label_targets, pad_idx)
        else:
            collated = self.collater_frm_label(
                label_targets, audio_size, audio_starts, rate, pad_idx
            )
        tgt, lens, ntok = collated
        targets_list.append(tgt)
        lengths_list.append(lens)
        ntokens_list.append(ntok)
    return targets_list, lengths_list, ntokens_list
def num_tokens(self, index):
    """Return the token count of example ``index``; same value as ``size``."""
    n_tokens = self.size(index)
    return n_tokens
def size(self, index):
    """Return the length of example ``index`` as it will appear in a batch.

    With ``self.pad_audio`` the raw length is reported; otherwise the
    length is capped at ``self.max_sample_size``.
    """
    raw = self.sizes[index]
    return raw if self.pad_audio else min(raw, self.max_sample_size)
def ordered_indices(self):
    """Return all example indices ordered from the largest size to the smallest.

    Ties between equally sized examples are resolved by a random
    permutation when ``self.shuffle`` is set and by the example index
    otherwise.
    """
    count = len(self)
    tie_break = np.random.permutation(count) if self.shuffle else np.arange(count)
    # np.lexsort uses the LAST key as the primary one: sort by size first,
    # then by the tie-break key; reversing yields descending sizes.
    return np.lexsort([tie_break, self.sizes])[::-1]
def postprocess(self, wav, cur_sample_rate):
    """Validate a loaded waveform and optionally normalise it.

    A 2-D input is averaged over its last dimension; the result must be
    1-D. When ``self.normalize`` is set, layer normalisation over the whole
    waveform is applied without tracking gradients.

    Raises:
        Exception: if ``cur_sample_rate`` differs from ``self.sample_rate``.
    """
    mono = wav.mean(-1) if wav.dim() == 2 else wav
    assert mono.dim() == 1, mono.dim()

    if cur_sample_rate != self.sample_rate:
        raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

    if not self.normalize:
        return mono
    with torch.no_grad():
        return F.layer_norm(mono, mono.shape)

View File

@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import math
from typing import List, Optional, NamedTuple
import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
ConcatDataset,
LanguagePairDataset,
FileAudioDataset,
data_utils,
)
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class ModalityDatasetItem(NamedTuple):
    """One sub-dataset of a ``MultiModalityDataset`` plus its batching limits."""

    # Name of the modality; reported to the model as ``net_input["mode"]``.
    datasetname: str
    # The wrapped dataset. The annotation used to be the builtin function
    # ``any``, which is not a type; a quoted forward reference keeps the
    # class importable even where ``FairseqDataset`` is not resolvable.
    dataset: "FairseqDataset"
    # Forwarded to ``filter_indices_by_size`` when batches are built.
    max_positions: List[int]
    # Forwarded to ``batch_by_size`` as ``max_tokens`` / ``max_sentences``.
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None
def resampling_dataset_present(ds):
    """Return True if ``ds`` is, or transitively wraps, a ``ResamplingDataset``.

    The search descends into every member of a ``ConcatDataset`` and into
    the ``dataset`` attribute of any other wrapper that exposes one.
    """
    if isinstance(ds, ResamplingDataset):
        return True
    if isinstance(ds, ConcatDataset):
        for member in ds.datasets:
            if resampling_dataset_present(member):
                return True
        return False
    # Generic wrappers expose the wrapped dataset as ``.dataset``.
    return hasattr(ds, "dataset") and resampling_dataset_present(ds.dataset)
# MultiModalityDataset: concatenates multiple datasets with different modalities.
# Compared with ConcatDataset it 1) can sample data given the ratios for different datasets, and
# 2) adds a mode to indicate which type of dataset the samples come from.
# It is used together with GroupedEpochBatchIterator to generate mini-batches with samples
# from the same type of dataset.
# If only one dataset is used, it behaves like the original dataset with the mode added.
class MultiModalityDataset(ConcatDataset):
    """Concatenation of several datasets, each tagged with a modality name.

    ``__getitem__`` returns ``(dataset_idx, sample)`` so that ``collater``
    can hand a batch to the sub-dataset it came from and record that
    sub-dataset's name under ``net_input["mode"]``. Batches are built per
    sub-dataset (see ``get_batch_samplers``).
    """

    def __init__(self, datasets: List[ModalityDatasetItem]):
        items = list(datasets)
        dsets = [item.dataset for item in items]
        # Every sub-dataset is passed to the base class with a weight of 1.0.
        super().__init__(dsets, [1.0 for _ in dsets])
        self.max_tokens = [item.max_tokens for item in items]
        self.max_positions = [item.max_positions for item in items]
        self.max_sentences = [item.max_sentences for item in items]
        self.id_to_mode = [item.datasetname for item in items]
        # One list of batches per sub-dataset, filled by get_raw_batch_samplers.
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        """Forward the epoch to the base class and remember it for seeding."""
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        """Return ``(dataset_idx, sample)`` for the global index ``idx``."""
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return (dataset_idx, self.datasets[dataset_idx][sample_idx])

    def collater(self, samples):
        """Collate samples of a single sub-dataset and tag the batch with its mode."""
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # A batch must never mix samples from different sub-datasets.
        assert all(dataset_idx == s[0] for s in samples)
        batch = self.datasets[dataset_idx].collater([s[1] for s in samples])
        batch["net_input"]["mode"] = self.id_to_mode[dataset_idx]
        return batch

    def size(self, index: int):
        """Return the size of example ``index``."""
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        """Sizes of all examples; delegated directly when only one dataset is wrapped."""
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """Return the sub-datasets' own index orderings.

        With a single sub-dataset its ordering is returned as is; otherwise
        the result is a list holding one ordering per sub-dataset (indices
        are local to each sub-dataset).
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        prev_end = 0
        for d_idx, ds in enumerate(self.datasets):
            end = self.cumulative_sizes[d_idx]
            # cumulative_sizes must agree with the actual sub-dataset length.
            assert end - prev_end == len(ds)
            indices_group.append(ds.ordered_indices())
            prev_end = end
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        """Build or refresh the per-dataset batches in ``raw_sub_batch_samplers``.

        Indices in these batches are local to each sub-dataset.
        """
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()

        for i, ds in enumerate(self.datasets):
            already_built = i < len(self.raw_sub_batch_samplers)
            # If we have a ResamplingDataset, the same id can correspond to a
            # different sample in the next epoch, so its batches need to be
            # rebuilt at every epoch; everything else is built only once.
            if already_built and not resampling_dataset_present(ds):
                logger.info(f"dataset {i} is valid and it is not re-sampled")
                continue
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if already_built:
                self.raw_sub_batch_samplers[i] = sub_batch_sampler
            else:
                self.raw_sub_batch_samplers.append(sub_batch_sampler)

    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        """Return one list of batches per sub-dataset, using global indices.

        ``mult_ratios[i]`` scales the number of batches of sub-dataset ``i``:
        the integer part repeats the full batch list, the fractional part
        adds a seeded random subset of it.
        """
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                # Shift local indices into the concatenated index space.
                offset = self.cumulative_sizes[i - 1]
                batches = [
                    [local + offset for local in batch]
                    for batch in self.raw_sub_batch_samplers[i]
                ]
            else:
                batches = list(self.raw_sub_batch_samplers[i])

            ratio = mult_ratios[i]
            if ratio == 1:
                logger.info(
                    "dataset {} batch number is {} ".format(
                        self.id_to_mode[i], len(batches)
                    )
                )
                batch_samplers.append(batches)
                continue

            direction = "increased" if ratio > 1 else "decreased"
            logger.info(
                "number of batch for the dataset {} is {} from {} to {}".format(
                    self.id_to_mode[i],
                    direction,
                    len(batches),
                    int(len(batches) * ratio),
                )
            )
            whole = math.floor(ratio)
            scaled = batches * whole
            if whole != ratio:
                # The seed depends on the epoch so the extra subset changes
                # from epoch to epoch but stays reproducible.
                with data_utils.numpy_seed(seed + self._cur_epoch):
                    np.random.shuffle(batches)
                    extra = int((ratio - whole) * len(batches))
                scaled = scaled + batches[:extra]
            batch_samplers.append(scaled)

        return batch_samplers
class LangPairMaskDataset(FairseqDataset):
    """Wrapper around a ``LanguagePairDataset`` that masks source tokens.

    When ``mask_ratio`` is positive, selected source positions are replaced
    by ``noise_id``: a random subset (``mask_type="random"``) or the
    trailing part of the sentence (``mask_type="tail"``). A leading
    ``src_bos`` and a trailing ``src_eos`` are never masked. Everything
    else is delegated to the wrapped dataset.
    """

    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # Read through on every access: the wrapped dataset may compute
        # its sizes dynamically.
        return self.dataset.sizes

    def get_batch_shapes(self):
        """Return the wrapped dataset's batch shapes, falling back to its buckets."""
        wrapped = self.dataset
        if hasattr(wrapped, "get_batch_shapes"):
            return wrapped.get_batch_shapes()
        return wrapped.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        """Return a copy of ``sample`` whose source has masked positions set to ``noise_id``."""
        src_item = sample["source"]
        length = len(src_item)
        if self.mask_type == "random":
            # Each position is masked independently with probability mask_ratio.
            mask = torch.rand(length).le(self.mask_ratio)
        else:
            # "tail": keep the leading (1 - mask_ratio) share, mask the rest.
            keep = int(length * (1 - self.mask_ratio))
            mask = torch.ones(length)
            mask[:keep] = 0
            mask = mask.eq(1)
        # Sentence boundary markers always survive.
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        masked_source = src_item.masked_fill(mask, self.noise_id)
        return {
            "id": sample["id"],
            "source": masked_source,
            "target": sample["target"],
        }

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            return self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)
class FileAudioDatasetWrapper(FileAudioDataset):
    """``FileAudioDataset`` whose batches use speech-to-text style ``net_input`` keys."""

    def collater(self, samples):
        """Collate with the parent class, then rename and add ``net_input`` entries.

        ``source`` is renamed to ``src_tokens``; ``prev_output_tokens``,
        ``src_lengths`` and ``alignment`` are added with value ``None``.
        """
        batch = super().collater(samples)
        if len(batch) == 0:
            return {}
        net_input = batch["net_input"]
        net_input["src_tokens"] = net_input.pop("source")
        net_input["prev_output_tokens"] = None
        net_input["src_lengths"] = None
        net_input["alignment"] = None
        return batch

View File

@@ -0,0 +1,393 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import io
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
logger = logging.getLogger(__name__)
class RawAudioDataset(FairseqDataset):
    """Base class for datasets of raw 1-D waveforms.

    Subclasses implement ``__getitem__`` and fill ``self.sizes``. This class
    provides waveform post-processing, cropping/padding collation, optional
    bucketing of batch lengths and optional pre-computation of mask indices.
    """

    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        # No explicit limit means "effectively unbounded".
        self.max_sample_size = (
            sys.maxsize if max_sample_size is None else max_sample_size
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            # NOTE(review): eval() executes the configured string as Python.
            # It is kept because the value may be an expression rather than a
            # plain literal; only pass configuration from trusted sources.
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        """Average a 2-D input to 1-D, check the sample rate, optionally normalise.

        Raises:
            Exception: if ``curr_sample_rate`` differs from ``self.sample_rate``.
        """
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if not self.normalize:
            return feats
        with torch.no_grad():
            return F.layer_norm(feats, feats.shape)

    def crop_to_max_size(self, wav, target_size):
        """Return a random window of ``target_size`` elements from ``wav``.

        Inputs that already fit are returned unchanged.
        """
        total = len(wav)
        excess = total - target_size
        if excess <= 0:
            return wav

        # randint's upper bound is exclusive: offsets 0..excess are possible.
        offset = np.random.randint(0, excess + 1)
        return wav[offset : total - excess + offset]

    def _compute_mask_indices(self, dims, padding_mask):
        """Compute time and channel mask indices for a ``(B, T, C)`` batch.

        Either result is ``None`` when the corresponding probability in
        ``self.mask_compute_kwargs`` is not positive.
        """
        B, T, C = dims
        cfg = self.mask_compute_kwargs
        mask_indices, mask_channel_indices = None, None
        if cfg["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                cfg["mask_prob"],
                cfg["mask_length"],
                cfg["mask_selection"],
                cfg["mask_other"],
                min_masks=2,
                no_overlap=cfg["no_mask_overlap"],
                min_space=cfg["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if cfg["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                cfg["mask_channel_prob"],
                cfg["mask_channel_length"],
                cfg["mask_channel_selection"],
                cfg["mask_channel_other"],
                no_overlap=cfg["no_mask_channel_overlap"],
                min_space=cfg["mask_channel_min_space"],
            )
            # Repeat the (B, C) channel mask along a new time dimension.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        """Right-pad the last dimension of ``tensor`` by ``num_pad`` entries of ``value``."""
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        """Merge samples into a mini-batch dictionary.

        Samples whose ``"source"`` is ``None`` are dropped. With ``self.pad``
        the batch length is the longest clip, otherwise the shortest; both
        are capped at ``self.max_sample_size``. Shorter clips are zero-padded
        and longer clips randomly cropped.
        """
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(src) for src in sources]

        reference = max(sizes) if self.pad else min(sizes)
        target_size = min(reference, self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for row, (source, length) in enumerate(zip(sources, sizes)):
            diff = length - target_size
            if diff == 0:
                collated_sources[row] = source
            elif diff < 0:
                assert self.pad
                collated_sources[row] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                # diff is negative: this marks the last -diff positions.
                padding_mask[row, diff:] = True
            else:
                collated_sources[row] = self.crop_to_max_size(source, target_size)

        net_input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            net_input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            # Pad the whole batch up to the largest bucket among its samples.
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                net_input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                net_input["padding_mask"] = self._bucket_tensor(
                    padding_mask, num_pad, True
                )

        if self.compute_mask_indices:
            B = net_input["source"].size(0)
            T = self._get_mask_indices_dims(net_input["source"].size(-1))
            frame_padding = net_input["padding_mask"].clone()
            # Drop trailing samples so the length divides evenly into T
            # frames; a frame counts as padding only if all its samples do.
            extra = frame_padding.size(1) % T
            if extra > 0:
                frame_padding = frame_padding[:, :-extra]
            frame_padding = frame_padding.view(frame_padding.size(0), T, -1)
            frame_padding = frame_padding.all(-1)
            net_input["padding_count"] = frame_padding.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                frame_padding,
            )
            net_input["mask_indices"] = mask_indices
            net_input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = net_input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        """Return the length left after the configured conv layers, cached per ``size``."""
        if size not in self._features_size_map:
            length_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                length_out = length_in + 2 * padding - dilation * (kernel_size - 1) - 1
                length_out = 1 + length_out // stride
                length_in = length_out
            self._features_size_map[size] = length_out
        return self._features_size_map[size]

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raw = self.sizes[index]
        return raw if self.pad else min(raw, self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if not self.shuffle:
            return np.arange(len(self))

        tie_break = np.random.permutation(len(self))
        capped_sizes = np.minimum(
            np.array(self.sizes),
            self.max_sample_size,
        )
        # lexsort's last key is primary: sort by capped size with random
        # tie-breaking, then reverse to get the largest examples first.
        return np.lexsort([tie_break, capped_sizes])[::-1]

    def set_bucket_info(self, num_buckets):
        """Record ``num_buckets`` and, when positive, precompute the bucket sizes."""
        self.num_buckets = num_buckets
        if self.num_buckets <= 0:
            return
        self._collated_sizes = np.minimum(
            np.array(self.sizes),
            self.max_sample_size,
        )
        self.buckets = get_buckets(
            self._collated_sizes,
            self.num_buckets,
        )
        self._bucketed_sizes = get_bucketed_sizes(
            self._collated_sizes, self.buckets
        )
        logger.info(
            f"{len(self.buckets)} bucket(s) for the audio dataset: "
            f"{self.buckets}"
        )
class FileAudioDataset(RawAudioDataset):
    """Raw-audio dataset described by a manifest file.

    The first manifest line is the root directory; every following line is
    ``<file name>\\t<size>``. Entries smaller than ``min_sample_size`` are
    skipped and their manifest line numbers kept in ``skipped_indices``.
    """

    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        # Optional optimisation: keep file names in a pyarrow array. Any
        # ordinary failure (pyarrow missing, conversion error) falls back to
        # the plain list. ``except Exception`` rather than a bare ``except``
        # so that KeyboardInterrupt/SystemExit are not swallowed.
        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except Exception:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        """Load and post-process the waveform of example ``index``.

        Returns:
            ``{"id": index, "source": feats}`` with ``feats`` being the
            post-processed float waveform tensor.
        """
        import soundfile as sf

        fn = self.fnames[index]
        # pyarrow scalars must be converted back to Python objects.
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            # The path points at a byte range inside a stored zip: read that
            # range and hand soundfile an in-memory file instead.
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}
class BinarizedAudioDataset(RawAudioDataset):
    """Raw-audio dataset whose file names are stored in a binarized index.

    Expects ``dict.txt``, ``<split>`` (indexed dataset) and
    ``<split>.lengths`` inside ``data_dir``; an optional ``<split>.root``
    file supplies a directory that is prepended to every file name.
    """

    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        self.root_dir = None
        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as root_file:
                self.root_dir = next(root_file).strip()

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)

        lengths_path = os.path.join(data_dir, f"{split}.lengths")
        with open(lengths_path, "r") as lengths_file:
            for line in lengths_file:
                sz = int(line.rstrip())
                # Filtering would desynchronise the lengths from the index.
                assert (
                    sz >= min_sample_size
                ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
                self.sizes.append(sz)

        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")

    def __getitem__(self, index):
        """Load and post-process the waveform of example ``index``."""
        import soundfile as sf

        fname = self.fnames_dict.string(self.fnames[index], separator="")
        if self.root_dir:
            fname = os.path.join(self.root_dir, fname)

        wav, curr_sample_rate = sf.read(fname)
        feats = self.postprocess(torch.from_numpy(wav).float(), curr_sample_rate)
        return {"id": index, "source": feats}

View File

@@ -0,0 +1,379 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
TextTargetMultitaskData,
_collate_frames,
)
logger = logging.getLogger(__name__)
@dataclass
class SpeechToSpeechDatasetItem:
    """A single speech-to-speech example as returned by ``__getitem__``."""

    # Position of the example inside its dataset.
    index: int
    # Source-side audio tensor.
    source: torch.Tensor
    # Target tensor (frames or unit ids); optional.
    target: Optional[torch.Tensor] = None
    # Target speaker tensor; optional.
    target_speaker: Optional[torch.Tensor] = None
    # Dictionary index of the target language tag; optional.
    tgt_lang_tag: Optional[int] = None
class SpeechToSpeechDataset(SpeechToTextDataset):
    """Speech-to-speech dataset built on top of ``SpeechToTextDataset``.

    The target is either an audio/feature file (``target_is_code=False``) or
    a string of discrete units that is encoded with ``tgt_dict``
    (``target_is_code=True``; in that mode ``tgt_audio_paths`` holds the unit
    strings and is also passed to the parent class as ``tgt_texts``).
    """

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2SDataConfig,
        src_audio_paths: List[str],
        src_n_frames: List[int],
        tgt_audio_paths: List[str],
        tgt_n_frames: List[int],
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
    ):
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            cfg=data_cfg,
            audio_paths=src_audio_paths,
            n_frames=src_n_frames,
            ids=ids,
            tgt_dict=tgt_dict,
            tgt_texts=tgt_audio_paths if target_is_code else None,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            n_frames_per_step=n_frames_per_step,
        )

        self.tgt_audio_paths = tgt_audio_paths
        self.tgt_lens = [n // self.n_frames_per_step for n in tgt_n_frames]

        # Unit targets cannot be encoded without a dictionary.
        assert not target_is_code or tgt_dict is not None
        self.target_is_code = target_is_code

        assert len(tgt_audio_paths) == self.n_samples
        assert len(tgt_n_frames) == self.n_samples

        self.tgt_speakers = None
        if self.cfg.target_speaker_embed:
            rows = SpeechToTextDatasetCreator._load_samples_from_tsv(
                self.cfg.target_speaker_embed, split
            )
            embed_by_id = {row["id"]: row["speaker_embed"] for row in rows}
            self.tgt_speakers = [embed_by_id[sample_id] for sample_id in self.ids]
            assert len(self.tgt_speakers) == self.n_samples

        logger.info(self.__repr__())

    def pack_units(self, input: torch.Tensor) -> torch.Tensor:
        """Merge every ``n_frames_per_step`` consecutive units into a single id.

        The input is returned unchanged when ``n_frames_per_step <= 1``.
        Otherwise the last element of ``input`` is carried over as the last
        element of the result, and each group of units before it is read as
        one number in base ``vocab_size``.
        """
        if self.n_frames_per_step <= 1:
            return input

        # Skip <bos>, <pad>, <eos>, <unk>, which are specific to the fairseq
        # dictionary and occupy the first ids.
        offset = 4
        vocab_size = len(self.tgt_dict) - offset

        assert input.dim() == 1
        # Drop the trailing <eos> and group the remaining units per step.
        grouped = input[:-1].view(-1, self.n_frames_per_step) - offset
        place_values = [
            vocab_size ** (self.n_frames_per_step - 1 - pos)
            for pos in range(self.n_frames_per_step)
        ]
        place_values = torch.LongTensor(place_values).squeeze(0)
        packed = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(
            input[-1]
        )
        packed[:-1] = (grouped * place_values).sum(dim=1) + offset
        return packed

    def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
        """Load the source audio, the target and the optional speaker tensor."""
        source = self._get_source_audio(index)

        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_as_bos:
            # The language tag will later replace the <bos> of the target.
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)

        if self.target_is_code:
            target = self.tgt_dict.encode_line(
                self.tgt_audio_paths[index],
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.n_frames_per_step > 1:
                n_tgt_frame = target.size(0) - 1  # exclude <eos>
                # Truncate to a multiple of n_frames_per_step, then re-append <eos>.
                keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
                target = torch.cat(
                    (
                        target[:keep_n_tgt_frame],
                        target.new_full((1,), self.tgt_dict.eos()),
                    ),
                    dim=0,
                )
        else:
            target = get_features_or_waveform(self.tgt_audio_paths[index])
            target = torch.from_numpy(target).float()
            target = self.pack_frames(target)

        if self.tgt_speakers:
            tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
            tgt_spk = torch.from_numpy(tgt_spk).float()
        else:
            tgt_spk = torch.FloatTensor([])

        return SpeechToSpeechDatasetItem(
            index=index,
            source=source,
            target=target,
            target_speaker=tgt_spk,
            tgt_lang_tag=tgt_lang_tag,
        )

    def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
        """Collate the targets of ``samples``.

        Returns a ``(target, prev_output_tokens, target_lengths)`` triple.
        """
        if self.target_is_code:
            target = fairseq_data_utils.collate_tokens(
                [s.target for s in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            # convert stacked units to a single id
            packed = [self.pack_units(s.target) for s in samples]
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                packed,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            target_lengths = torch.tensor(
                [p.size(0) for p in packed], dtype=torch.long
            )
        else:
            target = _collate_frames([s.target for s in samples], is_audio_input=False)
            bsz, _, dim = target.size()
            # Shift right by one frame: start with an all-zero frame and
            # drop the last one.
            prev_output_tokens = torch.cat(
                (target.new_full((bsz, 1, dim), 0.0), target[:, :-1, :]), dim=1
            )
            target_lengths = torch.tensor(
                [s.target.size(0) for s in samples], dtype=torch.long
            )

        return target, prev_output_tokens, target_lengths

    def collater(
        self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
    ) -> Dict:
        """Build a mini-batch sorted by descending source length.

        With ``return_order=True`` the sort permutation is added under
        ``"order"``.
        """
        if len(samples) == 0:
            return {}

        indices = torch.tensor([s.index for s in samples], dtype=torch.long)
        frames = _collate_frames([s.source for s in samples], self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([s.source.size(0) for s in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, prev_output_tokens, target_lengths = self._collate_target(samples)
        target = target.index_select(0, order)
        target_lengths = target_lengths.index_select(0, order)
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(s.target.size(0) for s in samples)

        tgt_speakers = None
        if self.cfg.target_speaker_embed:
            tgt_speakers = _collate_frames(
                [s.target_speaker for s in samples], is_audio_input=True
            ).index_select(0, order)

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "tgt_speaker": tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
        }
        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            # Rows are already sorted, so look the tag up through ``order``.
            for row, original_pos in enumerate(order.tolist()):
                net_input["prev_output_tokens"][row][0] = samples[
                    original_pos
                ].tgt_lang_tag

        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out
class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
    """``SpeechToSpeechDataset`` that also returns targets for auxiliary tasks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # task name -> task dataset; filled through add_multitask_dataset.
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        """Register ``task_data`` as the target source for ``task_name``."""
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
        """Return the speech-to-speech item plus one target per registered task."""
        s2s_data = super().__getitem__(index)

        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        multitask_target = {
            task_name: task_dataset.get(sample_id, tgt_lang)
            for task_name, task_dataset in self.multitask_data.items()
        }

        return s2s_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        """Collate the main batch and add every task's batch under ``"multitask"``.

        Task tensors are reordered with the same permutation the parent
        collater applied to the main batch.
        """
        if len(samples) == 0:
            return {}

        out = super().collater([item for item, _ in samples], return_order=True)
        # "order" is only needed here; it must not leak into the batch.
        order = out.pop("order")

        for task_name, task_dataset in self.multitask_data.items():
            multitask_out = out.setdefault("multitask", {})
            task_samples = [targets[task_name] for _, targets in samples]
            task_target = task_dataset.collater(task_samples)
            multitask_out[task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
                "net_input": {
                    "prev_output_tokens": task_target[
                        "prev_output_tokens"
                    ].index_select(0, order),
                },
            }

        return out
class SpeechToSpeechDatasetCreator(object):
    """Factory that builds speech-to-speech datasets from TSV manifests."""

    # mandatory columns
    KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
    KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
    # optional columns
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        data_cfg: S2SDataConfig,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        """Build one dataset from already-loaded manifest rows.

        A ``SpeechToSpeechMultitaskDataset`` is created when ``multitask`` is
        a non-empty mapping, a plain ``SpeechToSpeechDataset`` otherwise.
        """
        audio_root = Path(data_cfg.audio_root)

        def under_root(relative):
            return (audio_root / relative).as_posix()

        ids = [row[cls.KEY_ID] for row in samples]
        src_audio_paths = [under_root(row[cls.KEY_SRC_AUDIO]) for row in samples]
        # Unit targets are stored inline in the manifest, not as file paths.
        tgt_audio_paths = [
            row[cls.KEY_TGT_AUDIO]
            if target_is_code
            else under_root(row[cls.KEY_TGT_AUDIO])
            for row in samples
        ]
        src_n_frames = [int(row[cls.KEY_SRC_N_FRAMES]) for row in samples]
        tgt_n_frames = [int(row[cls.KEY_TGT_N_FRAMES]) for row in samples]
        src_langs = [row.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for row in samples]
        tgt_langs = [row.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for row in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        if has_multitask:
            dataset_cls = SpeechToSpeechMultitaskDataset
        else:
            dataset_cls = SpeechToSpeechDataset

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            src_audio_paths=src_audio_paths,
            src_n_frames=src_n_frames,
            tgt_audio_paths=tgt_audio_paths,
            tgt_n_frames=tgt_n_frames,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            target_is_code=target_is_code,
            tgt_dict=tgt_dict,
            n_frames_per_step=n_frames_per_step,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2SDataConfig,
        splits: str,
        is_train_split: bool,
        epoch: int,
        seed: int,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        """Load every comma-separated split in ``splits`` from TSV files under ``root``.

        A single split is returned directly; several splits are wrapped in a
        ``ConcatDataset``. ``epoch`` and ``seed`` are accepted but not used
        by this method.
        """
        datasets = []
        for split in splits.split(","):
            rows = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
            datasets.append(
                cls._from_list(
                    split_name=split,
                    is_train_split=is_train_split,
                    samples=rows,
                    data_cfg=data_cfg,
                    target_is_code=target_is_code,
                    tgt_dict=tgt_dict,
                    n_frames_per_step=n_frames_per_step,
                    multitask=multitask,
                )
            )
        if len(datasets) > 1:
            return ConcatDataset(datasets)
        return datasets[0]

View File

@@ -0,0 +1,733 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
logger = logging.getLogger(__name__)
def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Zero-pad a list of variable-length frames and stack them into one tensor

    Args:
        frames (list): tensors whose first dimension is the length L[i];
            2D frames of size L[i]*f_dim by default, 1D when
            ``is_audio_input`` is set
        is_audio_input (bool): build a 2D output (no feature dimension)
    Returns:
        tensor of size len(frames)*len_max*f_dim (or len(frames)*len_max for
        audio input) where len_max is max of L[i]; it inherits dtype and
        device from the first frame
    """
    batch_size = len(frames)
    longest = max(f.size(0) for f in frames)
    first = frames[0]
    if is_audio_input:
        shape = (batch_size, longest)
    else:
        shape = (batch_size, longest, first.size(1))
    padded = first.new_zeros(shape)
    # rows of ``padded`` are views, so slice assignment writes in place
    for row, frame in zip(padded, frames):
        row[: frame.size(0)] = frame
    return padded
def _is_int_or_np_int(n):
    """Tell whether ``n`` is a Python int or a NumPy scalar that unwraps to one."""
    if isinstance(n, int):
        return True
    if not isinstance(n, np.generic):
        return False
    # NumPy scalars expose their Python equivalent through ``item()``
    return isinstance(n.item(), int)
@dataclass
class SpeechToTextDatasetItem(object):
    """One sample as returned by ``SpeechToTextDataset.__getitem__``."""

    # position of the sample in the dataset
    index: int
    # source audio (features or waveform) as a tensor
    source: torch.Tensor
    # encoded target tokens; None when the dataset has no target texts
    target: Optional[torch.Tensor] = None
    # speaker id looked up via ``speaker_to_id``; None when no mapping is set
    speaker_id: Optional[int] = None
class SpeechToTextDataset(FairseqDataset):
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(
    self,
    split: str,
    is_train_split: bool,
    cfg: S2TDataConfig,
    audio_paths: List[str],
    n_frames: List[int],
    src_texts: Optional[List[str]] = None,
    tgt_texts: Optional[List[str]] = None,
    speakers: Optional[List[str]] = None,
    src_langs: Optional[List[str]] = None,
    tgt_langs: Optional[List[str]] = None,
    ids: Optional[List[str]] = None,
    tgt_dict: Optional[Dictionary] = None,
    pre_tokenizer=None,
    bpe_tokenizer=None,
    n_frames_per_step=1,
    speaker_to_id=None,
    append_eos=True,
):
    """Store the per-sample columns, validate them and build the transforms.

    Args:
        split (str): split name (used for config lookups and logging)
        is_train_split (bool): whether this split is used for training
        cfg (S2TDataConfig): data config
        audio_paths (list): one audio path per sample (must be non-empty)
        n_frames (list): one frame count per sample
        src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids (list,
            optional): per-sample columns; each must have one entry per
            sample when given
        tgt_dict (Dictionary, optional): required exactly when ``tgt_texts``
            is given
        pre_tokenizer, bpe_tokenizer: optional tokenizers applied in that
            order to target text
        n_frames_per_step (int): number of frames packed into one step
        speaker_to_id: optional mapping from speaker name to id
        append_eos (bool): passed to ``tgt_dict.encode_line`` for targets
    """
    self.split, self.is_train_split = split, is_train_split
    self.cfg = cfg
    self.audio_paths, self.n_frames = audio_paths, n_frames
    self.n_samples = len(audio_paths)
    # every provided column must be aligned with audio_paths
    assert len(n_frames) == self.n_samples > 0
    assert src_texts is None or len(src_texts) == self.n_samples
    assert tgt_texts is None or len(tgt_texts) == self.n_samples
    assert speakers is None or len(speakers) == self.n_samples
    assert src_langs is None or len(src_langs) == self.n_samples
    assert tgt_langs is None or len(tgt_langs) == self.n_samples
    assert ids is None or len(ids) == self.n_samples
    # target dictionary and target texts come together or not at all
    assert (tgt_dict is None and tgt_texts is None) or (
        tgt_dict is not None and tgt_texts is not None
    )
    self.src_texts, self.tgt_texts = src_texts, tgt_texts
    self.src_langs, self.tgt_langs = src_langs, tgt_langs
    self.speakers = speakers
    self.tgt_dict = tgt_dict
    # needs cfg, tgt_langs and tgt_dict, all assigned above
    self.check_tgt_lang_tag()
    self.ids = ids
    # shuffling is only ever enabled for training splits
    self.shuffle = cfg.shuffle if is_train_split else False
    self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
        self.cfg.get_feature_transforms(split, is_train_split)
    )
    self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
        self.cfg.get_waveform_transforms(split, is_train_split)
    )
    # TODO: add these to data_cfg.py
    self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
        self.cfg.get_dataset_transforms(split, is_train_split)
    )
    # check proper usage of transforms
    if self.feature_transforms and self.cfg.use_audio_input:
        logger.warning(
            "Feature transforms will not be applied. To use feature transforms, "
            "set use_audio_input as False in config."
        )
    self.pre_tokenizer = pre_tokenizer
    self.bpe_tokenizer = bpe_tokenizer
    self.n_frames_per_step = n_frames_per_step
    self.speaker_to_id = speaker_to_id
    # tokenizes every target, so the tokenizers must be assigned before this
    self.tgt_lens = self.get_tgt_lens_and_check_oov()
    self.append_eos = append_eos
    logger.info(self.__repr__())
def get_tgt_lens_and_check_oov(self):
    """Return the tokenized length of every target and log the OOV rate.

    Without target texts every length is 0 and nothing is logged.
    """
    if self.tgt_texts is None:
        return [0] * self.n_samples
    lengths = []
    total = oov = 0
    for sample_idx in range(self.n_samples):
        pieces = self.get_tokenized_tgt_text(sample_idx).split(" ")
        lengths.append(len(pieces))
        total += len(pieces)
        # a token is OOV when the dictionary maps it to the unknown index
        oov += sum(
            1
            for piece in pieces
            if self.tgt_dict.index(piece) == self.tgt_dict.unk_index
        )
    logger.info(f"'{self.split}' has {oov / total * 100:.2f}% OOV")
    return lengths
def __repr__(self):
    """Summarize the split, its size and the configured transforms."""
    fields = ", ".join(
        (
            f'split="{self.split}"',
            f"n_samples={self.n_samples:_}",
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}",
            f"n_frames_per_step={self.n_frames_per_step}",
            f"shuffle={self.shuffle}",
            f"feature_transforms={self.feature_transforms}",
            f"waveform_transforms={self.waveform_transforms}",
            f"dataset_transforms={self.dataset_transforms}",
        )
    )
    return f"{self.__class__.__name__}({fields})"
@classmethod
def is_lang_tag(cls, token):
    """Match ``token`` against the language-tag template.

    Returns the ``re.Match`` (group 1 is the language) or None.
    """
    # the "{}" placeholder of the template becomes a capture group
    lang_tag_regex = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
    matched = re.match(lang_tag_regex, token)
    return matched
def check_tgt_lang_tag(self):
    """Assert that every target language has its tag in the target dictionary.

    Only checked when ``cfg.prepend_tgt_lang_tag`` is enabled.
    """
    if not self.cfg.prepend_tgt_lang_tag:
        return
    assert self.tgt_langs is not None and self.tgt_dict is not None
    expected_tags = [
        self.LANG_TAG_TEMPLATE.format(lang) for lang in set(self.tgt_langs)
    ]
    assert all(tag in self.tgt_dict for tag in expected_tags)
@classmethod
def tokenize(cls, tokenizer, text: str):
    """Encode ``text`` with ``tokenizer``; pass it through when there is none."""
    if tokenizer is None:
        return text
    return tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
    """Return the target text for ``index`` after both tokenizers.

    Args:
        index (int or List[int]): a single index, or several indices whose
            texts are joined with single spaces before tokenization
    """
    if _is_int_or_np_int(index):
        raw = self.tgt_texts[index]
    else:
        raw = " ".join(self.tgt_texts[i] for i in index)
    # pre-tokenizer first, then the subword tokenizer
    for tokenizer in (self.pre_tokenizer, self.bpe_tokenizer):
        raw = self.tokenize(tokenizer, raw)
    return raw
def pack_frames(self, feature: torch.Tensor):
    """Group every ``n_frames_per_step`` consecutive frames into one row.

    Trailing frames that do not fill a whole step are dropped. With a step
    of 1 the input is returned untouched.
    """
    step = self.n_frames_per_step
    if step == 1:
        return feature
    n_packed = feature.shape[0] // step
    usable = feature[: step * n_packed]
    return usable.reshape(n_packed, -1)
@classmethod
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
    """Look up the dictionary index of the tag for ``lang``.

    Raises:
        AssertionError: if the tag maps to the dictionary's unknown index
    """
    tag = cls.LANG_TAG_TEMPLATE.format(lang)
    idx = dictionary.index(tag)
    assert idx != dictionary.unk()
    return idx
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
    """
    Gives source audio for given index with any relevant transforms
    applied. For ConcatAug, source audios for given indices are
    concatenated in given order.

    Args:
        index (int or List[int]): index, or in the case of ConcatAug,
            indices, to pull the source audio for
    Returns:
        source audios concatenated for given indices with
        relevant transforms applied
    """

    def _load(sample_idx):
        return get_features_or_waveform(
            self.audio_paths[sample_idx],
            need_waveform=self.cfg.use_audio_input,
            use_sample_rate=self.cfg.use_sample_rate,
            waveform_transforms=self.waveform_transforms,
        )

    if _is_int_or_np_int(index):
        raw = _load(index)
    else:
        raw = np.concatenate([_load(i) for i in index])

    if not self.cfg.use_audio_input:
        # feature path: transforms run on the NumPy array before conversion
        if self.feature_transforms is not None:
            raw = self.feature_transforms(raw)
        return torch.from_numpy(raw).float()

    # waveform path: optional standardization over the whole signal
    waveform = torch.from_numpy(raw).float()
    if self.cfg.standardize_audio:
        with torch.no_grad():
            waveform = F.layer_norm(waveform, waveform.shape)
    return waveform
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
    """Load the sample at ``index`` and encode its target, if any.

    When a ConcatAugment transform is configured, the source audio and the
    target text are built from the indices it selects rather than from
    ``index`` alone; language tag and speaker still come from ``index``.
    """
    has_concat = self.dataset_transforms.has_transform(ConcatAugment)
    if has_concat:
        concat = self.dataset_transforms.get_transform(ConcatAugment)
        indices = concat.find_indices(index, self.n_frames, self.n_samples)
    source = self._get_source_audio(indices if has_concat else index)
    source = self.pack_frames(source)
    target = None
    if self.tgt_texts is not None:
        tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
        target = self.tgt_dict.encode_line(
            tokenized, add_if_not_exist=False, append_eos=self.append_eos
        ).long()
        if self.cfg.prepend_tgt_lang_tag:
            # put the language tag in front of the encoded target
            lang_tag_idx = self.get_lang_tag_idx(
                self.tgt_langs[index], self.tgt_dict
            )
            target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
    # NOTE(review): this branch dereferences self.tgt_dict and concatenates
    # `target`, so it presumes target texts are present -- confirm the
    # intended nesting against the upstream file.
    if self.cfg.prepend_bos_and_append_tgt_lang_tag:
        bos = torch.LongTensor([self.tgt_dict.bos()])
        lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
        assert lang_tag_idx != self.tgt_dict.unk()
        lang_tag_idx = torch.LongTensor([lang_tag_idx])
        # resulting layout: bos, target tokens, language tag
        target = torch.cat((bos, target, lang_tag_idx), 0)
    speaker_id = None
    if self.speaker_to_id is not None:
        speaker_id = self.speaker_to_id[self.speakers[index]]
    return SpeechToTextDatasetItem(
        index=index, source=source, target=target, speaker_id=speaker_id
    )
def __len__(self):
    """Return the number of samples in the dataset."""
    return self.n_samples
def collater(
    self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
) -> Dict:
    """Merge samples into a mini-batch sorted by descending source length.

    Args:
        samples (list): items produced by ``__getitem__``
        return_order (bool): also return the sort permutation under the
            ``"order"`` key so callers can reorder their own tensors
    Returns:
        an empty dict for an empty sample list, otherwise a dict with keys
        ``id``, ``net_input`` (``src_tokens``, ``src_lengths``,
        ``prev_output_tokens``), ``speaker``, ``target``,
        ``target_lengths``, ``ntokens`` and ``nsentences``
    """
    if len(samples) == 0:
        return {}
    indices = torch.tensor([x.index for x in samples], dtype=torch.long)
    sources = [x.source for x in samples]
    # NoisyOverlapAugment is applied on the batch, and only to waveform input
    has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
    if has_NOAug and self.cfg.use_audio_input:
        NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
        sources = NOAug(sources)
    frames = _collate_frames(sources, self.cfg.use_audio_input)
    # sort samples by descending number of frames
    n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
    n_frames, order = n_frames.sort(descending=True)
    indices = indices.index_select(0, order)
    frames = frames.index_select(0, order)
    target, target_lengths = None, None
    prev_output_tokens = None
    ntokens = None
    if self.tgt_texts is not None:
        # every target-side tensor is built in sample order and then
        # reordered with the same permutation as the sources
        target = fairseq_data_utils.collate_tokens(
            [x.target for x in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        target = target.index_select(0, order)
        target_lengths = torch.tensor(
            [x.target.size(0) for x in samples], dtype=torch.long
        ).index_select(0, order)
        # same tokens, collated with move_eos_to_beginning=True
        prev_output_tokens = fairseq_data_utils.collate_tokens(
            [x.target for x in samples],
            self.tgt_dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        )
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(x.target.size(0) for x in samples)
    speaker = None
    if self.speaker_to_id is not None:
        # one speaker id per row, as a column vector
        speaker = (
            torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
            .index_select(0, order)
            .view(-1, 1)
        )
    net_input = {
        "src_tokens": frames,
        "src_lengths": n_frames,
        "prev_output_tokens": prev_output_tokens,
    }
    out = {
        "id": indices,
        "net_input": net_input,
        "speaker": speaker,
        "target": target,
        "target_lengths": target_lengths,
        "ntokens": ntokens,
        "nsentences": len(samples),
    }
    if return_order:
        out["order"] = order
    return out
def num_tokens(self, index):
    """Return the number of source frames of the sample at ``index``."""
    return self.n_frames[index]
def size(self, index):
    """Return ``(source frames, tokenized target length)`` for ``index``."""
    return self.n_frames[index], self.tgt_lens[index]
@property
def sizes(self):
    """Frame counts of all samples as a NumPy array (rebuilt on every access)."""
    return np.array(self.n_frames)
@property
def can_reuse_epoch_itr_across_epochs(self):
    """Always True for this dataset."""
    return True
def ordered_indices(self):
    """Return sample indices sorted by descending number of frames.

    Ties are broken by a random permutation when shuffling is enabled and
    by the original order otherwise.
    """
    n_items = len(self)
    if self.shuffle:
        tie_breaker = np.random.permutation(n_items)
    else:
        tie_breaker = np.arange(n_items)
    # lexsort treats the LAST key as primary: negated frame counts give the
    # descending order, the tie breaker decides among equal lengths
    negated_frames = [-count for count in self.n_frames]
    return np.lexsort([tie_breaker, negated_frames])
def prefetch(self, indices):
    """Prefetching is not supported by this dataset.

    Args:
        indices: sample indices the caller would like preloaded (ignored)
    Raises:
        NotImplementedError: always
    """
    # The previous ``raise False`` was itself an error: Python refuses to
    # raise a non-exception and reports an unrelated TypeError instead.
    raise NotImplementedError(
        f"{self.__class__.__name__} does not support prefetching"
    )
class TextTargetMultitaskData(object):
# mandatory columns
KEY_ID, KEY_TEXT = "id", "tgt_text"
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(self, args, split, tgt_dict):
    """Load the text targets of one auxiliary task for ``split``.

    Args:
        args: task config; this method reads ``data``, ``decoder_type``,
            ``config``, ``prepend_bos_and_append_tgt_lang_tag``,
            ``eos_token`` and ``get_lang_tag_mapping`` from it
        split (str): split name; the manifest is ``<args.data>/<split>.tsv``
        tgt_dict: dictionary used to encode the target text
    """
    samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
    # sample id -> raw target text
    self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
    self.dict = tgt_dict
    # no EOS is appended when the decoder type is "ctc"
    self.append_eos = args.decoder_type != "ctc"
    self.pre_tokenizer = self.build_tokenizer(args)
    self.bpe_tokenizer = self.build_bpe(args)
    self.prepend_bos_and_append_tgt_lang_tag = (
        args.prepend_bos_and_append_tgt_lang_tag
    )
    self.eos_token = args.eos_token
    # NOTE(review): stored without calling it; get_lang_tag_idx later calls
    # ``.get`` on it, so this presumably is a property returning a mapping
    self.lang_tag_mapping = args.get_lang_tag_mapping
@classmethod
def is_lang_tag(cls, token):
    """Match ``token`` against the language-tag template.

    Returns the ``re.Match`` (group 1 is the language) or None.
    """
    # the "{}" placeholder of the template becomes a capture group
    lang_tag_regex = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
    matched = re.match(lang_tag_regex, token)
    return matched
@classmethod
def tokenize(cls, tokenizer, text: str):
    """Encode ``text`` with ``tokenizer``; pass it through when there is none."""
    if tokenizer is None:
        return text
    return tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: int):
    """Return the stored text for key ``index`` after both tokenizers.

    ``index`` is used as a key into ``self.data``.
    """
    text = self.data[index]
    # pre-tokenizer first, then the subword tokenizer
    for tokenizer in (self.pre_tokenizer, self.bpe_tokenizer):
        text = self.tokenize(tokenizer, text)
    return text
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
    """Look up the dictionary index of the (possibly remapped) tag for ``lang``.

    Raises:
        AssertionError: if the tag maps to the dictionary's unknown index
    """
    default_tag = self.LANG_TAG_TEMPLATE.format(lang)
    # the mapping may substitute another tag; fall back to the default one
    tag = self.lang_tag_mapping.get(default_tag, default_tag)
    idx = dictionary.index(tag)
    assert idx != dictionary.unk(), (lang, tag)
    return idx
def build_tokenizer(self, args):
    """Build the pre-tokenizer described by ``args.config``, or None if unset."""
    spec = args.config.get("pre_tokenizer")
    if spec is None:
        return None
    logger.info(f"pre-tokenizer: {spec}")
    return encoders.build_tokenizer(Namespace(**spec))
def build_bpe(self, args):
    """Build the subword tokenizer described by ``args.config``, or None if unset."""
    spec = args.config.get("bpe_tokenizer")
    if spec is None:
        return None
    logger.info(f"tokenizer: {spec}")
    return encoders.build_bpe(Namespace(**spec))
def get(self, sample_id, tgt_lang=None):
    """Return the encoded target for ``sample_id``.

    Args:
        sample_id: key into the loaded manifest
        tgt_lang: language used for the trailing tag when
            ``prepend_bos_and_append_tgt_lang_tag`` is enabled
    Returns:
        the encoded target tensor, or an empty ``IntTensor`` (after a
        warning) when the sample has no target
    """
    if sample_id not in self.data:
        logger.warning(f"no target for {sample_id}")
        return torch.IntTensor([])

    encoded = self.dict.encode_line(
        self.get_tokenized_tgt_text(sample_id),
        add_if_not_exist=False,
        append_eos=self.append_eos,
    )
    if not self.prepend_bos_and_append_tgt_lang_tag:
        return encoded

    # resulting layout: bos, target tokens, language tag
    bos = torch.LongTensor([self.dict.bos()])
    tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
    assert tag_idx != self.dict.unk()
    tag = torch.LongTensor([tag_idx])
    return torch.cat((bos, encoded, tag), 0)
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
    """Pad the task targets of a batch and derive the decoder inputs.

    Returns:
        a dict with ``prev_output_tokens``, ``target``, ``target_lengths``
        and ``ntokens``
    """

    def _pad(move_eos_to_beginning):
        return fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=move_eos_to_beginning,
        ).long()

    padded_target = _pad(False)
    shifted_target = _pad(True)
    lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
    total_tokens = sum(t.size(0) for t in samples)
    return {
        "prev_output_tokens": shifted_target,
        "target": padded_target,
        "target_lengths": lengths,
        "ntokens": total_tokens,
    }
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
    """Speech-to-text dataset that also yields targets for auxiliary tasks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # task name -> task data object providing ``get`` and ``collater``
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        """Register the target data of one auxiliary task."""
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
        """Return the base item plus one target per registered task."""
        s2t_data = super().__getitem__(index)
        sample_id, tgt_lang = self.ids[index], self.tgt_langs[index]
        multitask_target = {
            task_name: task_dataset.get(sample_id, tgt_lang)
            for task_name, task_dataset in self.multitask_data.items()
        }
        return s2t_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        """Collate the base items, then add per-task targets in the same order."""
        if len(samples) == 0:
            return {}

        out = super().collater([item for item, _ in samples], return_order=True)
        # the permutation is only needed here and is not part of the batch
        order = out.pop("order")

        for task_name, task_dataset in self.multitask_data.items():
            task_target = task_dataset.collater(
                [targets[task_name] for _, targets in samples]
            )
            out.setdefault("multitask", {})[task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(
                    0, order
                ),
                "ntokens": task_target["ntokens"],
                "net_input": {
                    "prev_output_tokens": task_target[
                        "prev_output_tokens"
                    ].index_select(0, order),
                },
            }
        return out
class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
@classmethod
def _from_list(
    cls,
    split_name: str,
    is_train_split,
    samples: List[Dict],
    cfg: S2TDataConfig,
    tgt_dict,
    pre_tokenizer,
    bpe_tokenizer,
    n_frames_per_step,
    speaker_to_id,
    multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
    """Turn manifest rows into a dataset.

    Args:
        split_name (str): name of the split the rows belong to
        is_train_split: forwarded to the dataset
        samples (list): manifest rows as dicts keyed by column name
        cfg (S2TDataConfig): data config; ``audio_root`` prefixes audio paths
        tgt_dict, pre_tokenizer, bpe_tokenizer, n_frames_per_step,
            speaker_to_id: forwarded to the dataset
        multitask (dict, optional): task name -> task object; when non-empty
            a multitask dataset is built and one target set is attached per
            task
    """

    def _required(column):
        return [row[column] for row in samples]

    def _optional(column, default):
        return [row.get(column, default) for row in samples]

    audio_root = Path(cfg.audio_root)
    ids = _required(cls.KEY_ID)
    audio_paths = [
        (audio_root / rel_path).as_posix() for rel_path in _required(cls.KEY_AUDIO)
    ]
    n_frames = [int(count) for count in _required(cls.KEY_N_FRAMES)]
    tgt_texts = _required(cls.KEY_TGT_TEXT)
    src_texts = _optional(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
    speakers = _optional(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
    src_langs = _optional(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
    tgt_langs = _optional(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)

    has_multitask = multitask is not None and len(multitask.keys()) > 0
    if has_multitask:
        dataset_cls = SpeechToTextMultitaskDataset
    else:
        dataset_cls = SpeechToTextDataset

    ds = dataset_cls(
        split=split_name,
        is_train_split=is_train_split,
        cfg=cfg,
        audio_paths=audio_paths,
        n_frames=n_frames,
        src_texts=src_texts,
        tgt_texts=tgt_texts,
        speakers=speakers,
        src_langs=src_langs,
        tgt_langs=tgt_langs,
        ids=ids,
        tgt_dict=tgt_dict,
        pre_tokenizer=pre_tokenizer,
        bpe_tokenizer=bpe_tokenizer,
        n_frames_per_step=n_frames_per_step,
        speaker_to_id=speaker_to_id,
    )
    if not has_multitask:
        return ds

    for task_name, task_obj in multitask.items():
        ds.add_multitask_dataset(
            task_name,
            TextTargetMultitaskData(
                task_obj.args, split_name, task_obj.target_dictionary
            ),
        )
    return ds
@classmethod
def get_size_ratios(
    cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
) -> List[float]:
    """Size ratios for temperature-based sampling
    (https://arxiv.org/abs/1907.05019)

    Each dataset must hold exactly one language pair. Sizes are measured in
    frames and summed per language pair.

    Returns:
        one ratio per dataset, in the order of ``datasets``
    """
    pair_of_split = {}
    frames_per_pair = defaultdict(int)
    for dataset in datasets:
        pairs = {
            f"{src}->{tgt}"
            for src, tgt in zip(dataset.src_langs, dataset.tgt_langs)
        }
        assert len(pairs) == 1
        (pair,) = pairs
        pair_of_split[dataset.split] = pair
        frames_per_pair[pair] += sum(dataset.n_frames)

    total_frames = sum(frames_per_pair.values())
    natural_prob = {p: n / total_frames for p, n in frames_per_pair.items()}
    # temper the natural distribution with alpha, then renormalize
    tempered = {p: q**alpha for p, q in natural_prob.items()}
    tempered_sum = sum(tempered.values())
    target_prob = {p: q / tempered_sum for p, q in tempered.items()}
    ratio_of_pair = {
        p: (target_prob[p] * total_frames) / n for p, n in frames_per_pair.items()
    }
    size_ratio = [ratio_of_pair[pair_of_split[d.split]] for d in datasets]

    p_formatted = {
        p: f"{natural_prob[p]:.3f}->{target_prob[p]:.3f}" for p in frames_per_pair
    }
    logger.info(f"sampling probability balancing: {p_formatted}")
    sr_formatted = {d.split: f"{r:.3f}" for d, r in zip(datasets, size_ratio)}
    logger.info(f"balanced sampling size ratio: {sr_formatted}")
    return size_ratio
@classmethod
def _load_samples_from_tsv(cls, root: str, split: str):
tsv_path = Path(root) / f"{split}.tsv"
if not tsv_path.is_file():
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
samples = [dict(e) for e in reader]
if len(samples) == 0:
raise ValueError(f"Empty manifest: {tsv_path}")
return samples
@classmethod
def _from_tsv(
    cls,
    root: str,
    cfg: S2TDataConfig,
    split: str,
    tgt_dict,
    is_train_split: bool,
    pre_tokenizer,
    bpe_tokenizer,
    n_frames_per_step,
    speaker_to_id,
    multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
    """Load the manifest of one split and build its dataset.

    Args:
        root (str): directory containing ``<split>.tsv``
        cfg (S2TDataConfig): data config
        split (str): split name
        tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer,
            n_frames_per_step, speaker_to_id, multitask: forwarded to
            ``_from_list``
    """
    samples = cls._load_samples_from_tsv(root, split)
    # arguments are passed positionally, in the order _from_list declares them
    return cls._from_list(
        split,
        is_train_split,
        samples,
        cfg,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask,
    )
@classmethod
def from_tsv(
    cls,
    root: str,
    cfg: S2TDataConfig,
    splits: str,
    tgt_dict,
    pre_tokenizer,
    bpe_tokenizer,
    is_train_split: bool,
    epoch: int,
    seed: int,
    n_frames_per_step: int = 1,
    speaker_to_id=None,
    multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
    """Build one dataset per comma-separated split and combine them.

    For training with several splits and ``cfg.sampling_alpha != 1.0`` each
    dataset is first wrapped in a ``ResamplingDataset`` using the ratios
    from ``get_size_ratios``.

    Returns:
        the single dataset when only one split was given, otherwise a
        ``ConcatDataset`` over all of them
    """
    datasets = []
    for split in splits.split(","):
        datasets.append(
            cls._from_tsv(
                root=root,
                cfg=cfg,
                split=split,
                tgt_dict=tgt_dict,
                is_train_split=is_train_split,
                pre_tokenizer=pre_tokenizer,
                bpe_tokenizer=bpe_tokenizer,
                n_frames_per_step=n_frames_per_step,
                speaker_to_id=speaker_to_id,
                multitask=multitask,
            )
        )

    needs_resampling = (
        is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0
    )
    if needs_resampling:
        # temperature-based sampling
        size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
        datasets = [
            ResamplingDataset(
                dataset,
                size_ratio=ratio,
                seed=seed,
                epoch=epoch,
                replace=(ratio >= 1.0),
            )
            for ratio, dataset in zip(size_ratios, datasets)
        ]

    if len(datasets) > 1:
        return ConcatDataset(datasets)
    return datasets[0]

View File

@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional
import torch
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
)
logger = logging.getLogger(__name__)
class S2TJointDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def src_vocab_filename(self):
        """Name of the source-side fairseq vocabulary file under data root."""
        return self.config.get("src_vocab_filename", "src_dict.txt")

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Pre-tokenizer applied to source text before subword tokenization.

        Returns a dictionary whose `tokenizer` entry names the tokenizer;
        the other items are tokenizer-specific arguments. Tokenizers are
        defined in `fairseq.data.encoders.*`."""
        fallback = {"tokenizer": None}
        return self.config.get("src_pre_tokenizer", fallback)

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Subword tokenizer applied to source text after pre-tokenization.

        Returns a dictionary whose `bpe` entry names the tokenizer; the
        other items are tokenizer-specific arguments. Tokenizers are
        defined in `fairseq.data.encoders.*`."""
        fallback = {"bpe": None}
        return self.config.get("src_bpe_tokenizer", fallback)

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
        to-many multilingual setting). No change needed during inference.
        This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
        """
        legacy_value = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if legacy_value is not None:
            return legacy_value
        # deprecated key absent: fall back to its replacement
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
        input only) (alpha = 1 for no resampling)"""
        return self.config.get("sampling_text_alpha", 1.0)
class SpeechToTextJointDatasetItem(NamedTuple):
    """One sample as returned by ``SpeechToTextJointDataset.__getitem__``."""

    # position of the sample in the dataset
    index: int
    # source audio tensor taken from the base speech-to-text item
    source: torch.Tensor
    # encoded target tokens taken from the base speech-to-text item
    target: Optional[torch.Tensor] = None
    # encoded source text; None without source texts or a source dictionary
    src_txt_tokens: Optional[torch.Tensor] = None
    # target language tag index; set when cfg.prepend_tgt_lang_tag_no_change
    tgt_lang_tag: Optional[int] = None
    # source language tag index; set when use_src_lang_id > 0
    src_lang_tag: Optional[int] = None
    # per-sample alignment values; None when no alignment column was given
    tgt_alignment: Optional[torch.Tensor] = None
# use_src_lang_id:
# 0: don't use src_lang_id
# 1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
def __init__(
    self,
    split: str,
    is_train_split: bool,
    cfg: S2TJointDataConfig,
    audio_paths: List[str],
    n_frames: List[int],
    src_texts: Optional[List[str]] = None,
    tgt_texts: Optional[List[str]] = None,
    speakers: Optional[List[str]] = None,
    src_langs: Optional[List[str]] = None,
    tgt_langs: Optional[List[str]] = None,
    ids: Optional[List[str]] = None,
    tgt_dict: Optional[Dictionary] = None,
    src_dict: Optional[Dictionary] = None,
    pre_tokenizer=None,
    bpe_tokenizer=None,
    src_pre_tokenizer=None,
    src_bpe_tokenizer=None,
    append_eos: Optional[bool] = True,
    alignment: Optional[List[str]] = None,
    use_src_lang_id: Optional[int] = 0,
):
    """Extend the base dataset with source-text and alignment data.

    Args:
        src_dict (Dictionary, optional): dictionary for encoding source text
        src_pre_tokenizer, src_bpe_tokenizer: optional tokenizers applied in
            that order to source text
        alignment (list, optional): one whitespace-separated string of
            numbers per sample
        use_src_lang_id (int): 0 to ignore the source language id, 1 to put
            it in place of the source text EOS
        (remaining arguments are forwarded to the base class)
    """
    super().__init__(
        split,
        is_train_split,
        cfg,
        audio_paths,
        n_frames,
        src_texts=src_texts,
        tgt_texts=tgt_texts,
        speakers=speakers,
        src_langs=src_langs,
        tgt_langs=tgt_langs,
        ids=ids,
        tgt_dict=tgt_dict,
        pre_tokenizer=pre_tokenizer,
        bpe_tokenizer=bpe_tokenizer,
        append_eos=append_eos,
    )
    self.src_dict = src_dict
    self.src_pre_tokenizer = src_pre_tokenizer
    self.src_bpe_tokenizer = src_bpe_tokenizer
    self.use_src_lang_id = use_src_lang_id
    # each alignment string becomes a list of floats
    if alignment is None:
        self.alignment = None
    else:
        self.alignment = [
            [float(value) for value in row.split()] for row in alignment
        ]
def get_tokenized_src_text(self, index: int):
    """Return the source text at ``index`` after both source tokenizers."""
    text = self.src_texts[index]
    # pre-tokenizer first, then the subword tokenizer
    for tokenizer in (self.src_pre_tokenizer, self.src_bpe_tokenizer):
        text = self.tokenize(tokenizer, text)
    return text
def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
    """Return the base item extended with source text, tags and alignment."""
    s2t_dataset_item = super().__getitem__(index)
    src_tokens = None
    src_lang_tag = None
    if self.src_texts is not None and self.src_dict is not None:
        src_tokens = self.get_tokenized_src_text(index)
        # source text always gets an EOS appended
        src_tokens = self.src_dict.encode_line(
            src_tokens, add_if_not_exist=False, append_eos=True
        ).long()
        # NOTE(review): nesting reconstructed -- the tag lookup needs
        # self.src_dict, so it is placed under the same guard; confirm.
        if self.use_src_lang_id > 0:
            src_lang_tag = self.get_lang_tag_idx(
                self.src_langs[index], self.src_dict
            )
    tgt_lang_tag = None
    if self.cfg.prepend_tgt_lang_tag_no_change:
        # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
        tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
    ali = None
    if self.alignment is not None:
        ali = torch.Tensor(self.alignment[index]).float()
    return SpeechToTextJointDatasetItem(
        index=index,
        source=s2t_dataset_item.source,
        target=s2t_dataset_item.target,
        src_txt_tokens=src_tokens,
        tgt_lang_tag=tgt_lang_tag,
        src_lang_tag=src_lang_tag,
        tgt_alignment=ali,
    )
def __len__(self):
    """Return the number of samples in the dataset."""
    return self.n_samples
def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
    """Collate a batch, adding source text, alignment and language tags.

    The base collater sorts the batch and returns the permutation; every
    tensor built here is reordered with that same permutation.

    Returns:
        an empty dict for an empty batch, otherwise a dict with ``id``,
        ``net_input``, ``target``, ``target_lengths``, ``ntokens`` and
        ``nsentences``
    """
    s2t_out = super().collater(samples, return_order=True)
    if s2t_out == {}:
        return s2t_out
    net_input, order = s2t_out["net_input"], s2t_out["order"]
    if self.src_texts is not None and self.src_dict is not None:
        src_txt_tokens = fairseq_data_utils.collate_tokens(
            [x.src_txt_tokens for x in samples],
            self.src_dict.pad(),
            self.src_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        src_txt_lengths = torch.tensor(
            [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
        )
        if self.use_src_lang_id > 0:
            src_lang_idxs = torch.tensor(
                [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
            )
            if self.use_src_lang_id == 1:  # replace eos with lang_id
                # the last real token of each row sits at length - 1
                eos_idx = src_txt_lengths - 1
                src_txt_tokens.scatter_(
                    1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
                )
            else:
                raise NotImplementedError("Implementation is required")
        src_txt_tokens = src_txt_tokens.index_select(0, order)
        src_txt_lengths = src_txt_lengths.index_select(0, order)
        net_input["src_txt_tokens"] = src_txt_tokens
        net_input["src_txt_lengths"] = src_txt_lengths
    net_input["alignment"] = None
    if self.alignment is not None:
        max_len = max([s.tgt_alignment.size(0) for s in samples])
        # positions beyond a sample's own alignment keep the fill value 1.0
        alignment = torch.ones(len(samples), max_len).float()
        for i, s in enumerate(samples):
            cur_len = s.tgt_alignment.size(0)
            alignment[i][:cur_len].copy_(s.tgt_alignment)
        net_input["alignment"] = alignment.index_select(0, order)
    if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
        # prev_output_tokens is already sorted, so row i belongs to the
        # original sample order[i]; its first token becomes the language tag
        for i in range(len(samples)):
            net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
    out = {
        "id": s2t_out["id"],
        "net_input": net_input,
        "target": s2t_out["target"],
        "target_lengths": s2t_out["target_lengths"],
        "ntokens": s2t_out["ntokens"],
        "nsentences": len(samples),
    }
    return out
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
KEY_ALIGN = "align"
@classmethod
def _from_list(
    cls,
    split_name: str,
    is_train_split,
    samples: List[Dict],
    cfg: S2TJointDataConfig,
    tgt_dict,
    src_dict,
    pre_tokenizer,
    bpe_tokenizer,
    src_pre_tokenizer,
    src_bpe_tokenizer,
    append_eos,
    use_src_lang_id,
) -> SpeechToTextJointDataset:
    """Turn manifest rows into a joint speech/text dataset.

    Args:
        split_name (str): name of the split the rows belong to
        is_train_split: forwarded to the dataset
        samples (list): manifest rows as dicts keyed by column name; the
            alignment column is used when the first row has it
        cfg (S2TJointDataConfig): data config; ``audio_root`` prefixes
            audio paths
        (remaining arguments are forwarded to the dataset)
    """

    def _required(column):
        return [row[column] for row in samples]

    def _optional(column, default):
        return [row.get(column, default) for row in samples]

    audio_root = Path(cfg.audio_root)
    ids = _required(cls.KEY_ID)
    audio_paths = [
        (audio_root / rel_path).as_posix() for rel_path in _required(cls.KEY_AUDIO)
    ]
    n_frames = [int(count) for count in _required(cls.KEY_N_FRAMES)]
    tgt_texts = _required(cls.KEY_TGT_TEXT)
    src_texts = _optional(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
    speakers = _optional(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
    src_langs = _optional(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
    tgt_langs = _optional(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)
    # the first row decides whether the manifest carries alignments
    if cls.KEY_ALIGN in samples[0]:
        tgt_alignment = _required(cls.KEY_ALIGN)
    else:
        tgt_alignment = None

    return SpeechToTextJointDataset(
        split_name,
        is_train_split,
        cfg,
        audio_paths,
        n_frames,
        src_texts=src_texts,
        tgt_texts=tgt_texts,
        speakers=speakers,
        src_langs=src_langs,
        tgt_langs=tgt_langs,
        ids=ids,
        tgt_dict=tgt_dict,
        src_dict=src_dict,
        pre_tokenizer=pre_tokenizer,
        bpe_tokenizer=bpe_tokenizer,
        src_pre_tokenizer=src_pre_tokenizer,
        src_bpe_tokenizer=src_bpe_tokenizer,
        append_eos=append_eos,
        alignment=tgt_alignment,
        use_src_lang_id=use_src_lang_id,
    )
@classmethod
def _from_tsv(
    cls,
    root: str,
    cfg: S2TJointDataConfig,
    split: str,
    tgt_dict,
    src_dict,
    is_train_split: bool,
    pre_tokenizer,
    bpe_tokenizer,
    src_pre_tokenizer,
    src_bpe_tokenizer,
    append_eos: bool,
    use_src_lang_id: int,
) -> SpeechToTextJointDataset:
    """Load the manifest of one split and build its joint dataset.

    Args:
        root (str): directory containing ``<split>.tsv``
        cfg (S2TJointDataConfig): data config
        split (str): split name
        (remaining arguments are forwarded to ``_from_list``)
    """
    samples = cls._load_samples_from_tsv(root, split)
    # arguments are passed positionally, in the order _from_list declares them
    return cls._from_list(
        split,
        is_train_split,
        samples,
        cfg,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos,
        use_src_lang_id,
    )
@classmethod
def from_tsv(
    cls,
    root: str,
    cfg: S2TJointDataConfig,
    splits: str,
    tgt_dict,
    src_dict,
    pre_tokenizer,
    bpe_tokenizer,
    src_pre_tokenizer,
    src_bpe_tokenizer,
    is_train_split: bool,
    epoch: int,
    seed: int,
    append_eos: Optional[bool] = True,
    use_src_lang_id: Optional[int] = 0,
) -> SpeechToTextJointDataset:
    """Build one joint dataset per comma-separated split and combine them.

    For training with several splits and ``cfg.sampling_alpha != 1.0`` each
    dataset is first wrapped in a ``ResamplingDataset`` using the ratios
    from ``get_size_ratios``.

    Returns:
        the single dataset when only one split was given, otherwise a
        ``ConcatDataset`` over all of them
    """
    datasets = []
    for split in splits.split(","):
        datasets.append(
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            )
        )

    needs_resampling = (
        is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0
    )
    if needs_resampling:
        # temperature-based sampling
        size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
        datasets = [
            ResamplingDataset(
                dataset,
                size_ratio=ratio,
                seed=seed,
                epoch=epoch,
                replace=(ratio >= 1.0),
            )
            for ratio, dataset in zip(size_ratios, datasets)
        ]

    if len(datasets) > 1:
        return ConcatDataset(datasets)
    return datasets[0]

View File

@@ -0,0 +1,250 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
_collate_frames,
)
@dataclass
class TextToSpeechDatasetItem(object):
    """One sample as returned by ``TextToSpeechDataset.__getitem__``."""

    # position of the sample in the dataset
    index: int
    # source tensor taken from the base speech-to-text item
    source: torch.Tensor
    # encoded target tokens taken from the base speech-to-text item
    target: Optional[torch.Tensor] = None
    # speaker id taken from the base speech-to-text item
    speaker_id: Optional[int] = None
    # durations with a trailing 0 for EOS; None when not provided
    duration: Optional[torch.Tensor] = None
    # pitch values with a trailing 0 for EOS; None when not provided
    pitch: Optional[torch.Tensor] = None
    # energy values with a trailing 0 for EOS; None when not provided
    energy: Optional[torch.Tensor] = None
class TextToSpeechDataset(SpeechToTextDataset):
def __init__(
    self,
    split: str,
    is_train_split: bool,
    cfg: S2TDataConfig,
    audio_paths: List[str],
    n_frames: List[int],
    src_texts: Optional[List[str]] = None,
    tgt_texts: Optional[List[str]] = None,
    speakers: Optional[List[str]] = None,
    src_langs: Optional[List[str]] = None,
    tgt_langs: Optional[List[str]] = None,
    ids: Optional[List[str]] = None,
    tgt_dict: Optional[Dictionary] = None,
    pre_tokenizer=None,
    bpe_tokenizer=None,
    n_frames_per_step=1,
    speaker_to_id=None,
    durations: Optional[List[List[int]]] = None,
    pitches: Optional[List[str]] = None,
    energies: Optional[List[str]] = None,
):
    """Extend the base dataset with optional duration, pitch and energy data.

    Args:
        durations (list, optional): per-sample lists of integer durations
        pitches (list, optional): per-sample entries passed to
            ``get_features_or_waveform`` when an item is read
        energies (list, optional): per-sample entries passed to
            ``get_features_or_waveform`` when an item is read
        (remaining arguments are forwarded to the base class)
    """
    super(TextToSpeechDataset, self).__init__(
        split,
        is_train_split,
        cfg,
        audio_paths,
        n_frames,
        src_texts=src_texts,
        tgt_texts=tgt_texts,
        speakers=speakers,
        src_langs=src_langs,
        tgt_langs=tgt_langs,
        ids=ids,
        tgt_dict=tgt_dict,
        pre_tokenizer=pre_tokenizer,
        bpe_tokenizer=bpe_tokenizer,
        n_frames_per_step=n_frames_per_step,
        speaker_to_id=speaker_to_id,
    )
    self.durations = durations
    self.pitches = pitches
    self.energies = energies
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
s2t_item = super().__getitem__(index)
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToSpeechDatasetItem(
index=index,
source=s2t_item.source,
target=s2t_item.target,
speaker_id=s2t_item.speaker_id,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
bsz, _, d = feat.size()
prev_output_tokens = torch.cat(
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": feat,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    """Build ``TextToSpeechDataset`` objects from manifest rows."""

    # optional manifest columns read in addition to the parent creator's keys
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask=None,
    ) -> TextToSpeechDataset:
        """Create a dataset from a list of manifest rows (dicts).

        Durations are parsed from space-separated integers; pitch and energy
        entries are resolved relative to ``cfg.audio_root``.  Each of the
        three optional columns is dropped entirely (set to ``None``) as soon
        as one row lacks it.  ``multitask`` is accepted but not used here.
        """
        audio_root = Path(cfg.audio_root)

        def column(key, default):
            # optional column: fall back to *default* for rows without it
            return [row.get(key, default) for row in samples]

        def all_or_nothing(values):
            # a partially filled optional column is treated as absent
            return None if any(v is None for v in values) else values

        def under_audio_root(relative_paths):
            return [
                None if rel is None else (audio_root / rel).as_posix()
                for rel in relative_paths
            ]

        ids = [row[cls.KEY_ID] for row in samples]
        audio_paths = [(audio_root / row[cls.KEY_AUDIO]).as_posix() for row in samples]
        n_frames = [int(row[cls.KEY_N_FRAMES]) for row in samples]
        tgt_texts = [row[cls.KEY_TGT_TEXT] for row in samples]
        src_texts = column(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
        speakers = column(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
        src_langs = column(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
        tgt_langs = column(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)

        parsed_durations = [
            None if text is None else [int(d) for d in text.split(" ")]
            for text in column(cls.KEY_DURATION, None)
        ]
        durations = all_or_nothing(parsed_durations)
        pitches = all_or_nothing(under_audio_root(column(cls.KEY_PITCH, None)))
        energies = all_or_nothing(under_audio_root(column(cls.KEY_ENERGY, None)))

        return TextToSpeechDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )

View File

@@ -0,0 +1,48 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioWaveformTransform(AudioTransform):
    """Base class handed to the registry for waveform transforms.

    It adds no members of its own on top of ``AudioTransform``.
    """
# Module-level registry state; both containers are handed to
# register_audio_transform by register_audio_waveform_transform.
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()


def get_audio_waveform_transform(name):
    """Return the registry entry stored under *name*.

    Raises:
        KeyError: if nothing is registered under *name*.
    """
    registered = AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
    return registered
def register_audio_waveform_transform(name):
    """Return the class decorator that registers a waveform transform as *name*.

    Delegates to ``register_audio_transform`` with this module's base class,
    registry dict and class-name set.
    """
    decorator = register_audio_transform(
        name,
        AudioWaveformTransform,
        AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
        AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
    )
    return decorator
# Runs at import time on this package's directory.
# NOTE(review): presumably imports the "waveform" transform modules found
# there so that their @register_audio_waveform_transform decorators execute;
# confirm against fairseq.data.audio.import_transforms.
import_transforms(os.path.dirname(__file__), "waveform")
class CompositeAudioWaveformTransform(CompositeAudioTransform):
    """Chain of waveform transforms applied one after another."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Build the composite from *config* via the parent's factory helper."""
        transform_type = "waveform"
        return super()._from_config_dict(
            cls,
            transform_type,
            get_audio_waveform_transform,
            CompositeAudioWaveformTransform,
            config,
        )

    def __call__(self, x, sample_rate):
        """Feed ``(x, sample_rate)`` through every transform in order.

        Each transform returns a new ``(x, sample_rate)`` pair, which becomes
        the input of the next one; the final pair is returned.
        """
        for transform in self.transforms:
            x, sample_rate = transform(x, sample_rate)
        return x, sample_rate

View File

@@ -0,0 +1,201 @@
from pathlib import Path
import numpy as np
from math import ceil
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.waveform_transforms import (
AudioWaveformTransform,
register_audio_waveform_transform,
)
# Defaults used by the augmentation transforms below when the matching
# constructor argument / config key is not supplied.
# Bounds of the range the signal-to-noise ratio is drawn from (printed as dB
# by NoiseAugmentTransform.__repr__).
SNR_MIN = 5.0
SNR_MAX = 15.0
# Compared against np.random.random() in __call__: the input is returned
# unchanged whenever the random draw exceeds this value.
RATE = 0.25
# noises per second
NOISE_RATE = 1.0
# mean / standard deviation of the length of noises in seconds
NOISE_LEN_MEAN = 0.2
NOISE_LEN_STD = 0.05
class NoiseAugmentTransform(AudioWaveformTransform):
    """Mix a randomly chosen ``.wav`` sample into the input waveform.

    With probability governed by ``rate`` a noise signal is picked from
    ``samples_path`` and added to the source, scaled so that the mix has a
    signal-to-noise ratio drawn from ``[snr_min, snr_max]``.
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Instantiate from a config mapping, using module defaults for gaps."""
        settings = config if config is not None else {}
        return cls(
            settings.get("samples_path", None),
            settings.get("snr_min", SNR_MIN),
            settings.get("snr_max", SNR_MAX),
            settings.get("rate", RATE),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
    ):
        """Validate the arguments and index every ``*.wav`` under *samples_path*.

        Raises:
            AssertionError: on a missing path, an empty SNR range, a rate
                outside [0, 1], or when no audio file is found.
        """
        # Sanity checks
        assert (
            samples_path
        ), "need to provide path to audio samples for noise augmentation"
        assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
        assert 0 <= rate <= 1, "rate should be a float between 0 to 1"

        # recursive search for noise files
        self.paths = list(Path(samples_path).glob("**/*.wav"))
        self.n_samples = len(self.paths)
        assert self.n_samples > 0, f"no audio files found in {samples_path}"

        self.snr_min = snr_min
        self.snr_max = snr_max
        self.rate = rate

    def __repr__(self):
        fields = ", ".join(
            (
                f"n_samples={self.n_samples}",
                f"snr={self.snr_min}-{self.snr_max}dB",
                f"rate={self.rate}",
            )
        )
        return self.__class__.__name__ + "(" + fields + ")"

    def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
        """Load a random noise file and cut / repeat it to *goal_shape*.

        Returns an all-zero array of *goal_shape* when the loaded sample's
        dimensionality (or, for 2-D input, its first dimension) does not
        match the goal.
        """
        from fairseq.data.audio.audio_utils import get_waveform

        chosen_path = self.paths[np.random.randint(0, self.n_samples)]
        sample = get_waveform(
            chosen_path, always_2d=always_2d, output_sample_rate=use_sample_rate
        )[0]

        # Check dimensions match, else silently skip adding noise to sample
        # NOTE: SHOULD THIS QUIT WITH AN ERROR?
        is_2d = len(goal_shape) == 2
        rank_differs = len(goal_shape) != sample.ndim
        if rank_differs or (is_2d and goal_shape[0] != sample.shape[0]):
            return np.zeros(goal_shape)

        # Cut/repeat sample to size: tile until long enough, then take a
        # random window of the requested length
        len_dim = len(goal_shape) - 1
        goal_len = goal_shape[len_dim]
        n_repeat = ceil(goal_len / sample.shape[len_dim])
        repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
        start = np.random.randint(0, repeated.shape[len_dim] - goal_len + 1)
        stop = start + goal_len
        if is_2d:
            return repeated[:, start:stop]
        return repeated[start:stop]

    def _mix(self, source, noise, snr):
        """Add *noise* to *source*, scaled to the requested *snr*.

        The noise is left out (scale 0) when its mean power is zero.
        """

        def get_power(x):
            return np.mean(x**2)

        noise_power = get_power(noise)
        if noise_power:
            scl = np.sqrt(
                get_power(source) / (np.power(10, snr / 10) * noise_power)
            )
        else:
            scl = 0
        return 1 * source + scl * noise

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        """Produce the noise signal; subclasses override this hook."""
        return self.pick_sample(goal_shape, always_2d, use_sample_rate)

    def __call__(self, source, sample_rate):
        """Return ``(augmented_or_original_source, sample_rate)``."""
        if np.random.random() > self.rate:
            return source, sample_rate

        noise = self._get_noise(
            source.shape, always_2d=True, use_sample_rate=sample_rate
        )
        snr = rand_uniform(self.snr_min, self.snr_max)
        return self._mix(source, noise, snr), sample_rate
@register_audio_waveform_transform("musicaugment")
class MusicAugmentTransform(NoiseAugmentTransform):
    """``NoiseAugmentTransform`` registered under the name "musicaugment"."""
@register_audio_waveform_transform("backgroundnoiseaugment")
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
    """``NoiseAugmentTransform`` registered as "backgroundnoiseaugment"."""
@register_audio_waveform_transform("babbleaugment")
class BabbleAugmentTransform(NoiseAugmentTransform):
    """Noise built by mixing several randomly picked samples together."""

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        """Aggregate between 3 and 7 picked samples into one noise signal."""
        n_voices = np.random.randint(3, 8)
        agg_noise = self.pick_sample(goal_shape, always_2d, use_sample_rate)
        for n_mixed in range(1, n_voices):
            speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
            # SNR scaled by n_mixed (how many noise signals already in agg_noise)
            agg_noise = self._mix(agg_noise, speech, n_mixed)
        return agg_noise
@register_audio_waveform_transform("sporadicnoiseaugment")
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
    """Noise made of short bursts scattered over an otherwise silent signal."""

    @classmethod
    def from_config_dict(cls, config=None):
        """Instantiate from a config mapping, using module defaults for gaps."""
        settings = config if config is not None else {}
        return cls(
            settings.get("samples_path", None),
            settings.get("snr_min", SNR_MIN),
            settings.get("snr_max", SNR_MAX),
            settings.get("rate", RATE),
            settings.get("noise_rate", NOISE_RATE),
            settings.get("noise_len_mean", NOISE_LEN_MEAN),
            settings.get("noise_len_std", NOISE_LEN_STD),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
        noise_rate: float = NOISE_RATE,  # noises per second
        noise_len_mean: float = NOISE_LEN_MEAN,  # length of noises in seconds
        noise_len_std: float = NOISE_LEN_STD,
    ):
        """Store the burst parameters on top of the parent's configuration."""
        super().__init__(samples_path, snr_min, snr_max, rate)
        self.noise_rate = noise_rate
        self.noise_len_mean = noise_len_mean
        self.noise_len_std = noise_len_std

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        """Return zeros of *goal_shape* with random noise bursts added in.

        The number of bursts is ``noise_rate`` times the signal length in
        seconds; each burst gets a uniformly drawn start and a normally
        distributed length.  A burst whose end would reach or pass the end
        of the signal is skipped.
        """
        agg_noise = np.zeros(goal_shape)
        len_dim = len(goal_shape) - 1
        is_2d = len(goal_shape) == 2
        total_len = goal_shape[len_dim]

        n_noises = round(self.noise_rate * total_len / use_sample_rate)
        start_pointers = [
            round(rand_uniform(0, total_len)) for _ in range(n_noises)
        ]

        for start_pointer in start_pointers:
            len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
            # negative draws are clamped to a zero-length burst
            burst_len = round(max(0, len_seconds) * use_sample_rate)
            end_pointer = start_pointer + burst_len
            if end_pointer >= total_len:
                continue

            noise_shape = list(goal_shape)
            noise_shape[len_dim] = burst_len
            noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
            if is_2d:
                agg_noise[:, start_pointer:end_pointer] += noise
            else:
                agg_noise[start_pointer:end_pointer] += noise

        return agg_noise

View File

@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from . import FairseqDataset
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this will return:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have a 'source' key, which will become the 'target'
            after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch
        generate_fn (callable): function to generate backtranslations
        cuda (bool): use GPU for generation (default: ``True``)

    Returns:
        List[dict]: an updated list of samples with a backtranslated source
    """
    collated_samples = collate_fn(samples)
    batch = utils.move_to_cuda(collated_samples) if cuda else collated_samples
    generated_sources = generate_fn(batch)

    id_to_src = {sample["id"]: sample["source"] for sample in samples}

    # Go through each tgt sentence in batch and its corresponding best
    # generated hypothesis and create a backtranslation data pair
    # {id: id, source: generated backtranslation, target: original tgt}
    backtranslated = []
    for id_tensor, hypos in zip(collated_samples["id"], generated_sources):
        # convert once; also avoids shadowing the builtin ``id``
        sample_id = id_tensor.item()
        backtranslated.append(
            {
                "id": sample_id,
                "target": id_to_src[sample_id],
                # hypos[0] is taken as the best hypothesis
                "source": hypos[0]["tokens"].cpu(),
            }
        )
    return backtranslated
class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn function to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation; forced off when CUDA is unavailable
    """

    def __init__(
        self,
        tgt_dataset,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
    ):
        # extra keyword arguments are accepted and ignored
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        if output_collater is None:
            output_collater = tgt_dataset.collater
        self.output_collater = output_collater
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
        Returns a single sample from *tgt_dataset*. Note that backtranslation is
        not applied in this step; use :func:`collater` instead to backtranslate
        a batch of samples.
        """
        return self.tgt_dataset[index]

    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        """Install (or replace) the function used to generate backtranslations."""
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

        Using the samples from *tgt_dataset*, load a collated target sample to
        feed to the backtranslation model. Then take the backtranslation with
        the best score as the source and the original input as the target.

        Note: we expect *tgt_dataset* to provide a function `collater()` that
        will collate samples into the format expected by *backtranslation_fn*.
        After backtranslation, we will feed the new list of samples (i.e., the
        `(backtranslated source, original source)` pairs) to *output_collater*
        and return the result.

        Args:
            samples (List[dict]): samples to backtranslate and collate

        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        # dummy batches are passed through untouched
        if samples[0].get("is_dummy", False):
            return samples

        def generate(net_input):
            # looked up at call time so a function set later is picked up
            return self.backtranslation_fn(net_input)

        backtranslated = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
            generate_fn=generate,
            cuda=self.cuda,
        )
        return self.output_collater(backtranslated)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``.

        Note: we use *tgt_dataset* to approximate the length of the source
        sentence, since we do not know the actual length until after
        backtranslation.
        """
        approx_len = self.tgt_dataset.size(index)[0]
        return (approx_len, approx_len)

    @property
    def supports_prefetch(self):
        return getattr(self.tgt_dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class BaseWrapperDataset(FairseqDataset):
    """Dataset that forwards every operation to a wrapped ``dataset``.

    Subclasses override only the methods whose behaviour they change.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        """Use the wrapped dataset's collater, else torch's ``default_collate``."""
        if not hasattr(self.dataset, "collater"):
            return default_collate(samples)
        return self.dataset.collater(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        # wrapped datasets without the attribute count as not supporting it
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        batching_options = dict(
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )
        return self.dataset.batch_by_size(indices, **batching_options)

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        """Propagate the epoch to the parent class and, if supported, the wrapped dataset."""
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes
class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDataset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
        tensor_key (optional): when given, each item is indexed with this key
            to obtain the tensor to pad and the padded tensor is stored back
            under the same key; when ``None`` the item itself is the tensor
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
        tensor_key=None,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        self.buckets = get_buckets(sizes, num_buckets)
        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
        self._tensor_key = tensor_key

    def _set_tensor(self, item, val):
        """Return *val* itself, or *item* with *val* stored under the tensor key."""
        if self._tensor_key is None:
            return val
        item[self._tensor_key] = val
        return item

    def _get_tensor(self, item):
        """Return the tensor to pad: *item* itself or ``item[tensor_key]``."""
        if self._tensor_key is None:
            return item
        return item[self._tensor_key]

    def _pad(self, tensor, bucket_size, dim=-1):
        """Pad *tensor* along *dim* up to *bucket_size* with ``pad_idx``.

        Padding goes on the left or right according to ``left_pad``.
        """
        num_pad = bucket_size - tensor.size(dim)
        side_pad = (num_pad, 0) if self.left_pad else (0, num_pad)
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # so dimensions after *dim* need explicit (0, 0) pairs first.  For
        # the default dim=-1 this reduces to the plain (left, right) pair.
        n_trailing_dims = tensor.dim() - 1 - (dim % tensor.dim())
        return F.pad(
            tensor,
            (0, 0) * n_trailing_dims + side_pad,
            value=self.pad_idx,
        )

    def __getitem__(self, index):
        """Return the wrapped item with its tensor padded to its bucket size."""
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        tensor = self._get_tensor(item)
        padded = self._pad(tensor, bucket_size)
        return self._set_tensor(item, padded)

    @property
    def sizes(self):
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]

View File

@@ -0,0 +1,576 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.utils.data
from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
# Spacing between consecutive F0 frames.  get_f0 passes it to pYAAPT in
# milliseconds, and CodeDataset uses it to derive the F0-frames-per-code ratio.
F0_FRAME_SPACE = 0.005 # sec
logger = logging.getLogger(__name__)
class ExpressiveCodeDataConfig:
    """Read-only view over a JSON configuration file.

    The parsed JSON object is kept in ``config``; the properties below expose
    individual keys.  Keys accessed with ``[]`` are mandatory (``KeyError``
    when missing); ``f0_stats`` and ``multispkr`` default to ``None``.
    """

    def __init__(self, json_path):
        with open(json_path, "r") as config_file:
            self.config = json.load(config_file)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """naive or precomp"""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        """Return the ``f0_vq_naive_quantizer`` entry matching the F0 settings.

        The lookup key is ``"log"`` or ``"linear"`` followed by a
        normalisation suffix.  Std-normalisation without mean-normalisation
        maps to the ``"_none_norm"`` suffix.
        """
        scale = "log" if log else "linear"
        if norm_mean and norm_std:
            suffix = "_mean_std_norm"
        elif norm_mean:
            suffix = "_mean_norm"
        else:
            suffix = "_none_norm"
        return self.config["f0_vq_naive_quantizer"][scale + suffix]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)
def get_f0(audio, rate=16000):
    """Estimate the F0 contour of a 1-D waveform with pYAAPT.

    Args:
        audio: 1-D array holding the waveform.
        rate: sampling rate of *audio* in Hz.

    Returns:
        The ``samp_values`` of the pYAAPT pitch object, one value per
        ``F0_FRAME_SPACE`` seconds.

    Raises:
        ImportError: if amfm_decompy or librosa is not installed.
        AssertionError: if *audio* is not one-dimensional.
    """
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the install hint actually reaches the user.
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
        ) from err

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    # pad half a frame of silence on both sides
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,  # seconds -> milliseconds
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0
def interpolate_f0(f0):
    """Linearly interpolate the zero-valued frames of an F0 tensor.

    Zero entries lying between two non-zero entries are replaced by a linear
    interpolation of their neighbours; zeros before the first / after the
    last non-zero entry stay 0.  With fewer than two non-zero frames the
    values are returned unchanged.

    Args:
        f0: 1-D tensor of F0 values.

    Returns:
        Tensor with the dtype and device of *f0*.

    Raises:
        ImportError: if scipy is not installed.
    """
    try:
        from scipy.interpolate import interp1d
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the install hint actually reaches the user.
        raise ImportError("Please install scipy (`pip install scipy`)") from err

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0  # frames that carry an F0 value
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp
def naive_quantize(x, edges):
    """Map every value of *x* to the number of *edges* strictly below it.

    *x* is flattened; the result is a 1-D long tensor with one bin index per
    element of *x*.
    """
    values = x.view(-1, 1)
    boundaries = edges.view(1, -1)
    # compare every value against every edge, then count the edges exceeded
    exceeded = torch.gt(values, boundaries)
    return exceeded.long().sum(dim=1)
def load_wav(full_path):
    """Read an audio file with ``soundfile``.

    Args:
        full_path: path handed to ``soundfile.read``.

    Returns:
        Tuple ``(data, sampling_rate)`` as returned by ``soundfile.read``.

    Raises:
        ImportError: if soundfile is not installed.
    """
    try:
        import soundfile as sf
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the install hint actually reaches the user.
        raise ImportError(
            "Please install soundfile (`pip install SoundFile`)"
        ) from err
    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate
def parse_code(code_str, dictionary, append_eos):
    """Run-length encode a whitespace-separated unit string.

    Consecutive repetitions of a unit are collapsed; the collapsed units are
    encoded with ``dictionary.encode_line`` and the repetition counts are
    returned as durations.  When *append_eos* is set, a duration of 0 is
    appended (for the EOS symbol).

    Returns:
        Tuple ``(code, duration)`` of int16 tensors.
    """
    frame_units = torch.ShortTensor([int(tok) for tok in code_str.split()])
    units, duration = torch.unique_consecutive(frame_units, return_counts=True)

    unit_line = " ".join(str(u) for u in units.tolist())
    code = dictionary.encode_line(unit_line, append_eos).short()

    if append_eos:
        eos_duration = duration.new_zeros((1,))
        duration = torch.cat((duration, eos_duration), dim=0)  # eos
    return code, duration.short()
def parse_manifest(manifest, dictionary):
    """Parse a manifest with one Python dict literal per line.

    For every non-blank line the unit string is taken from the first present
    key among ``cpc_km100``, ``hubert_km100`` and ``phone`` and passed to
    ``parse_code`` (with EOS appended).

    Args:
        manifest: path of the manifest file.
        dictionary: dictionary forwarded to ``parse_code``.

    Returns:
        Tuple ``(audio_files, codes, durations, speakers)`` of parallel
        lists; a speaker is ``None`` when the line has no ``speaker`` key.

    Raises:
        AssertionError: if a line contains none of the known unit keys.
    """
    unit_keys = ("cpc_km100", "hubert_km100", "phone")
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        # iterate lazily instead of loading the whole file with readlines()
        for line in info:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue

            # SECURITY NOTE(review): eval() executes arbitrary expressions
            # found in the manifest.  Only use manifests from trusted
            # sources; consider ast.literal_eval if lines are pure literals.
            sample = eval(line)

            k = next((key for key in unit_keys if key in sample), None)
            if k is None:
                # explicit raise: a bare `assert False` vanishes under -O
                raise AssertionError("unknown format")

            code, duration = parse_code(sample[k], dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers
def parse_speaker(path, method):
    """Derive a speaker label from an audio path.

    Args:
        path: ``str`` or ``pathlib.Path`` of the audio file.
        method: one of

            * ``"parent_name"`` - name of the containing directory
            * ``"parent_parent_name"`` - name of the directory above that
            * ``"_"`` - file name up to the first underscore
            * ``"single"`` - the constant label ``"A"``
            * a callable - called with the ``Path`` and its result returned

    Raises:
        NotImplementedError: for any other *method*.
    """
    # isinstance (not an exact type comparison) so str subclasses convert too
    if isinstance(path, str):
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()
def get_f0_by_filename(filename, tgt_sampling_rate):
    """Load *filename* and return its raw (un-interpolated) F0 as float32.

    Raises:
        ValueError: if the file's sampling rate differs from
            *tgt_sampling_rate*.
    """
    audio, sampling_rate = load_wav(filename)
    rates_match = sampling_rate == tgt_sampling_rate
    if not rates_match:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
    raw_f0 = get_f0(audio, rate=tgt_sampling_rate)
    return torch.from_numpy(raw_f0.astype(np.float32))
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    """Average the F0 frames that fall under each code segment.

    *f0* is first truncated, or padded with its last value, so that its
    length equals ``int(f0_code_ratio * durations.sum())``; the mismatch may
    be at most *tol* frames.  Each segment's value is the mean of its
    non-zero F0 frames, or 0 if it has none.

    Returns:
        1-D tensor with one value per entry of *durations*.

    Raises:
        AssertionError: if the length mismatch exceeds *tol*, or if the
            accumulated segment lengths do not cover *f0* exactly.
    """
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        # too long: drop the surplus frames
        f0 = f0[:targ_len]
    elif diff < 0:
        # too short: repeat the last frame
        filler = f0.new_full((-diff,), f0[-1])
        f0 = torch.cat((f0, filler), 0)

    seg_f0s = []
    f0_offset = 0.0
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        window = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        voiced = window[window != 0]
        if voiced.numel() == 0:
            seg_value = torch.tensor(0).type(voiced.type())
        else:
            seg_value = voiced.mean()
        seg_f0s.append(seg_value)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)
class Paddings:
    """Padding values for the code, duration and F0 streams."""

    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code, self.dur, self.f0 = code_val, dur_val, f0_val
class Shifts:
    """Delay the duration and F0 streams relative to the code stream.

    ``shifts_str`` is ``"<dur_shift>,<f0_shift>"`` with non-negative integers.
    All three streams are extended by the larger of the two shifts: the code
    stream is padded on the right only, the other two get their own shift as
    left padding and the remainder as right padding.
    """

    def __init__(self, shifts_str, pads):
        self._shifts = [int(part) for part in shifts_str.split(",")]
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        """Pad a 1-D *seq* on both sides and return ``(padded, mask)``.

        ``mask`` is a boolean tensor that is True on the padded positions.
        """
        assert seq.ndim == 1
        left = seq.new_full((left_pad_num,), pad)
        right = seq.new_full((right_pad_num,), pad)
        mask = torch.cat(
            [
                torch.ones_like(left, dtype=torch.bool),
                torch.zeros_like(seq, dtype=torch.bool),
                torch.ones_like(right, dtype=torch.bool),
            ]
        )
        return torch.cat([left, seq, right]), mask

    def __call__(self, code, dur, f0):
        """Return ``(code, code_mask, dur, dur_mask, f0, f0_mask)``.

        Without any shift the inputs are returned as they are, together with
        all-False masks.
        """
        total = self.extra_length
        if total == 0:
            unpadded = []
            for stream in (code, dur, f0):
                unpadded.extend((stream, torch.zeros_like(stream).bool()))
            return tuple(unpadded)

        code, code_mask = self.shift_one(code, 0, total, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, total - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(f0, self.f0, total - self.f0, self.pads.f0)
        return code, code_mask, dur, dur_mask, f0, f0_mask
class CodeDataset(FairseqDataset):
def __init__(
self,
manifest,
dictionary,
dur_dictionary,
f0_dictionary,
config,
discrete_dur,
discrete_f0,
log_f0,
normalize_f0_mean,
normalize_f0_std,
interpolate_f0,
return_filename=False,
strip_filename=True,
shifts="0,0",
return_continuous_f0=False,
):
random.seed(1234)
self.dictionary = dictionary
self.dur_dictionary = dur_dictionary
self.f0_dictionary = f0_dictionary
self.config = config
# duration config
self.discrete_dur = discrete_dur
# pitch config
self.discrete_f0 = discrete_f0
self.log_f0 = log_f0
self.normalize_f0_mean = normalize_f0_mean
self.normalize_f0_std = normalize_f0_std
self.interpolate_f0 = interpolate_f0
self.return_filename = return_filename
self.strip_filename = strip_filename
self.f0_code_ratio = config.code_hop_size / (
config.sampling_rate * F0_FRAME_SPACE
)
# use lazy loading to avoid sharing file handlers across workers
self.manifest = manifest
self._codes = None
self._durs = None
self._f0s = None
with open(f"{manifest}.leng.txt", "r") as f:
lengs = [int(line.rstrip()) for line in f]
edges = np.cumsum([0] + lengs)
self.starts, self.ends = edges[:-1], edges[1:]
with open(f"{manifest}.path.txt", "r") as f:
self.file_names = [line.rstrip() for line in f]
logger.info(f"num entries: {len(self.starts)}")
if os.path.exists(f"{manifest}.f0_stat.pt"):
self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
elif config.f0_stats:
self.f0_stats = torch.load(config.f0_stats)
self.multispkr = config.multispkr
if config.multispkr:
with open(f"{manifest}.speaker.txt", "r") as f:
self.spkrs = [line.rstrip() for line in f]
self.id_to_spkr = sorted(self.spkrs)
self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}
self.pads = Paddings(
dictionary.pad(),
0, # use 0 for duration padding
f0_dictionary.pad() if discrete_f0 else -5.0,
)
self.shifts = Shifts(shifts, pads=self.pads)
self.return_continuous_f0 = return_continuous_f0
def get_data_handlers(self):
logging.info(f"loading data for {self.manifest}")
self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")
if self.discrete_f0:
if self.config.f0_vq_type == "precomp":
self._f0s = np.load(
f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
)
elif self.config.f0_vq_type == "naive":
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
quantizers_path = self.config.get_f0_vq_naive_quantizer(
self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
)
quantizers = torch.load(quantizers_path)
n_units = self.config.f0_vq_n_units
self._f0_quantizer = torch.from_numpy(quantizers[n_units])
else:
raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
else:
self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
def preprocess_f0(self, f0, stats):
"""
1. interpolate
2. log transform (keep unvoiced frame 0)
"""
# TODO: change this to be dependent on config for naive quantizer
f0 = f0.clone()
if self.interpolate_f0:
f0 = interpolate_f0(f0)
mask = f0 != 0 # only process voiced frames
if self.log_f0:
f0[mask] = f0[mask].log()
if self.normalize_f0_mean:
mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
f0[mask] = f0[mask] - mean
if self.normalize_f0_std:
std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
f0[mask] = f0[mask] / std
return f0
def _get_raw_item(self, index):
start, end = self.starts[index], self.ends[index]
if self._codes is None:
self.get_data_handlers()
code = torch.from_numpy(np.array(self._codes[start:end])).long()
dur = torch.from_numpy(np.array(self._durs[start:end]))
f0 = torch.from_numpy(np.array(self._f0s[start:end]))
return code, dur, f0
def __getitem__(self, index):
code, dur, f0 = self._get_raw_item(index)
code = torch.cat([code.new([self.dictionary.bos()]), code])
# use 0 for eos and bos
dur = torch.cat([dur.new([0]), dur])
if self.discrete_dur:
dur = self.dur_dictionary.encode_line(
" ".join(map(str, dur.tolist())), append_eos=False
).long()
else:
dur = dur.float()
# TODO: find a more elegant approach
raw_f0 = None
if self.discrete_f0:
if self.config.f0_vq_type == "precomp":
f0 = self.f0_dictionary.encode_line(
" ".join(map(str, f0.tolist())), append_eos=False
).long()
else:
f0 = f0.float()
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
if self.return_continuous_f0:
raw_f0 = f0
raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
f0 = naive_quantize(f0, self._f0_quantizer)
f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
else:
f0 = f0.float()
if self.multispkr:
f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
else:
f0 = self.preprocess_f0(f0, self.f0_stats)
f0 = torch.cat([f0.new([0]), f0])
if raw_f0 is not None:
*_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
else:
raw_f0_mask = None
code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
if raw_f0_mask is not None:
assert (raw_f0_mask == f0_mask).all()
# is a padded frame if either input or output is padded
feats = {
"source": code[:-1],
"target": code[1:],
"mask": code_mask[1:].logical_or(code_mask[:-1]),
"dur_source": dur[:-1],
"dur_target": dur[1:],
"dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
"f0_source": f0[:-1],
"f0_target": f0[1:],
"f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
}
if raw_f0 is not None:
feats["raw_f0"] = raw_f0[1:]
if self.return_filename:
fname = self.file_names[index]
feats["filename"] = (
fname if not self.strip_filename else Path(fname).with_suffix("").name
)
return feats
def __len__(self):
    """Return the number of examples, i.e. the number of entries in ``self.starts``."""
    return len(self.starts)
def size(self, index):
    """Return the span ``ends[index] - starts[index]`` plus ``self.shifts.extra_length``."""
    span = self.ends[index] - self.starts[index]
    return span + self.shifts.extra_length
def num_tokens(self, index):
    """Return the token count of example ``index``; identical to ``self.size(index)``."""
    token_count = self.size(index)
    return token_count
def collater(self, samples):
    """Merge a list of per-example feature dicts into one padded mini-batch.

    Every stream is right-padded with ``data_utils.collate_tokens``.
    Returns an empty dict when ``samples`` is empty.
    """
    pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
    if len(samples) == 0:
        return {}

    def right_pad(key, pad_value, eos_value):
        # Stack the ``key`` entry of every sample, padding on the right.
        return data_utils.collate_tokens(
            [sample[key] for sample in samples],
            pad_idx=pad_value,
            eos_idx=eos_value,
            left_pad=False,
        )

    src_tokens = right_pad("source", pad_idx, eos_idx)
    # appending padding, eos is there already
    tgt_tokens = right_pad("target", pad_idx, pad_idx)

    dur_pad = self.pads.dur
    src_durs = right_pad("dur_source", dur_pad, dur_pad)
    tgt_durs = right_pad("dur_target", dur_pad, dur_pad)

    f0_pad = self.pads.f0
    src_f0s = right_pad("f0_source", f0_pad, f0_pad)
    tgt_f0s = right_pad("f0_target", f0_pad, f0_pad)

    # masks are padded with 1
    mask = right_pad("mask", 1, 1)
    dur_mask = right_pad("dur_mask", 1, 1)
    f0_mask = right_pad("f0_mask", 1, 1)

    src_lengths = torch.LongTensor([sample["source"].numel() for sample in samples])
    n_tokens = sum(len(sample["source"]) for sample in samples)

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "dur_src": src_durs,
        "f0_src": src_f0s,
    }
    result = {
        "nsentences": len(samples),
        "ntokens": n_tokens,
        "net_input": net_input,
        "target": tgt_tokens,
        "dur_target": tgt_durs,
        "f0_target": tgt_f0s,
        "mask": mask,
        "dur_mask": dur_mask,
        "f0_mask": f0_mask,
    }

    first = samples[0]
    if "filename" in first:
        result["filename"] = [sample["filename"] for sample in samples]
    # TODO: remove this hack into the inference dataset
    if "prefix" in first:
        result["prefix"] = [sample["prefix"] for sample in samples]
    if "raw_f0" in first:
        result["raw_f0"] = right_pad("raw_f0", self.pads.f0, self.pads.f0)
    return result

View File

@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset):
    """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        """Collate with the wrapped dataset, then attach one color per sample.

        ``color_getter`` is called as ``color_getter(dataset, sample_id)``;
        an empty collated batch is returned unchanged.
        """
        batch = super().collater(samples)
        if len(batch) <= 0:
            return batch
        colors = [
            self.color_getter(self.dataset, sample["id"]) for sample in samples
        ]
        batch["net_input"]["colors"] = torch.tensor(colors, dtype=torch.long)
        return batch

View File

@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import numpy as np
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class ConcatDataset(FairseqDataset):
    """Concatenation of several datasets, each repeated by its sample ratio."""

    @staticmethod
    def cumsum(sequence, sample_ratios):
        """Return the running totals of ``int(ratio * len(dataset))``."""
        totals = []
        running = 0
        for dataset, ratio in zip(sequence, sample_ratios):
            running += int(ratio * len(dataset))
            totals.append(running)
        return totals

    def __init__(self, datasets, sample_ratios=1):
        super().__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            # a single int applies to every dataset
            sample_ratios = [sample_ratios for _ in self.datasets]
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(dataset) for dataset in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        which, local_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[which][local_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        """Map a global index to ``(dataset index, index inside that dataset)``."""
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx != 0 else 0
        # wrap around so up-sampled datasets re-use their real examples
        sample_idx = (idx - offset) % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        first = self.datasets[0]
        if not hasattr(first, "collater"):
            return default_collate(samples, **extra_args)
        return first.collater(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        which, local_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[which].size(local_idx)

    def num_tokens(self, index: int):
        example_size = self.size(index)
        return np.max(example_size)

    def attr(self, attr: str, index: int):
        """Return attribute ``attr`` of the dataset owning ``index`` (None if absent)."""
        which = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[which], attr, None)

    @property
    def sizes(self):
        tiled = []
        for dataset, ratio in zip(self.datasets, self.sample_ratios):
            dataset_sizes = dataset.sizes
            if not isinstance(dataset_sizes, np.ndarray):
                # Only support underlying dataset with single size array.
                assert isinstance(dataset_sizes, list)
                dataset_sizes = dataset_sizes[0]
            tiled.append(np.tile(dataset_sizes, ratio))
        return np.concatenate(tiled)

    @property
    def supports_prefetch(self):
        return all(dataset.supports_prefetch for dataset in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        sizes = self.sizes
        if not (isinstance(sizes, np.ndarray) and len(sizes.shape) > 1):
            return np.argsort(sizes)

        # special handling for concatenating lang_pair_datasets
        indices = np.arange(len(self))
        has_target_column = len(sizes.shape) > 0 and sizes.shape[1] > 1
        tgt_sizes = sizes[:, 1] if has_target_column else None
        src_sizes = sizes[:, 0] if has_target_column else sizes
        # sort by target length, then source length
        if tgt_sizes is not None:
            indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
        return indices[np.argsort(src_sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Forward each index to the dataset that owns it, in local coordinates."""
        lower = 0
        for upper, dataset in zip(self.cumulative_sizes, self.datasets):
            real_size = len(dataset)
            if getattr(dataset, "supports_prefetch", False):
                local = [(i - lower) % real_size for i in indices if lower <= i < upper]
                dataset.prefetch(local)
            lower = upper

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(
            dataset.can_reuse_epoch_itr_across_epochs for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for dataset in self.datasets:
            if hasattr(dataset, "set_epoch"):
                dataset.set_epoch(epoch)

View File

@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
    """Dataset whose item ``i`` is the concatenation of item ``i`` of every wrapped dataset."""

    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        lengths_match = all(len(ds) == len(datasets[0]) for ds in datasets)
        assert lengths_match, "datasets must have the same length"

    def __getitem__(self, index):
        pieces = [ds[index] for ds in self.datasets]
        return torch.cat(pieces)

    def __len__(self):
        first = self.datasets[0]
        return len(first)

    def collater(self, samples):
        # collation is delegated to the first wrapped dataset
        first = self.datasets[0]
        return first.collater(samples)

    @property
    def sizes(self):
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        # ordering is taken from the first wrapped dataset
        first = self.datasets[0]
        return first.ordered_indices()

    @property
    def supports_prefetch(self):
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        for ds in self.datasets:
            if not getattr(ds, "supports_prefetch", False):
                continue
            ds.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if not hasattr(ds, "set_epoch"):
                continue
            ds.set_epoch(epoch)

View File

@@ -0,0 +1,604 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
import contextlib
import itertools
import logging
import re
import warnings
from typing import Optional, Tuple
import numpy as np
import torch
from fairseq.file_io import PathManager
from fairseq import utils
import os
logger = logging.getLogger(__name__)
def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx

    Returns the two-element list from the first matching file name, or
    ``(None, None)`` when no file under ``path`` matches.
    """
    for filename in PathManager.ls(path):
        parts = filename.split(".")
        if len(parts) < 3:
            continue
        langs = parts[1].split("-")
        if len(langs) == 2:
            return langs
    return None, None
def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor.

    The width is the longest input, raised to ``pad_to_length`` and rounded
    up to a multiple of ``pad_to_multiple``; the number of rows is raised to
    ``pad_to_bsz``.  With ``move_eos_to_beginning`` each row starts with
    ``eos_idx`` (or the row's own last token when ``eos_idx`` is None)
    followed by all tokens except the last.
    """
    width = max(v.size(0) for v in values)
    if pad_to_length is not None:
        width = max(width, pad_to_length)
    if pad_to_multiple != 1 and width % pad_to_multiple != 0:
        width = int(((width - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    n_rows = len(values)
    if pad_to_bsz is not None:
        n_rows = max(n_rows, pad_to_bsz)

    padded = values[0].new(n_rows, width).fill_(pad_idx)

    for row_idx, tokens in enumerate(values):
        n_tokens = len(tokens)
        row = padded[row_idx]
        slot = row[width - n_tokens :] if left_pad else row[:n_tokens]
        assert slot.numel() == tokens.numel()
        if not move_eos_to_beginning:
            slot.copy_(tokens)
            continue
        # if no eos_idx is specified, then use the last token in src
        slot[0] = tokens[-1] if eos_idx is None else eos_idx
        slot[1:] = tokens[:-1]
    return padded
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.

    Returns:
        None when nothing could be loaded, the dataset itself when exactly
        one shard was loaded, otherwise a ConcatDataset over all shards.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    shards = []
    for shard_no in itertools.count():
        suffix = str(shard_no) if shard_no > 0 else ""
        shard_path = path + suffix
        try:
            shard_path = indexed_dataset.get_indexed_dataset_to_local(shard_path)
        except Exception as e:
            # a missing remote path is only logged; anything else propagates
            if "StorageException: [404] Path not found" not in str(e):
                raise e
            logger.warning(f"path_k: {e} not found")

        impl = dataset_impl
        if impl is None:
            impl = indexed_dataset.infer_dataset_impl(shard_path)
        dataset = indexed_dataset.make_dataset(
            shard_path,
            impl=impl or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        logger.info("loaded {:,} examples from: {}".format(len(dataset), shard_path))
        shards.append(dataset)
        if not combine:
            break

    if len(shards) == 0:
        return None
    if len(shards) == 1:
        return shards[0]
    return ConcatDataset(shards)
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global PRNG, restoring the previous state on exit.

    A ``seed`` of None leaves the PRNG untouched.  When additional seeds are
    given, all of them are hashed together into a single seed below 1e6.
    """
    if seed is None:
        yield
        return

    if addl_seeds:
        seed = int(hash((seed, *addl_seeds)) % 1e6)

    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_state)
def collect_filtered(function, iterable, filtered):
    """
    Like :func:`filter`, but rejected elements are appended to ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for item in iterable:
        if not function(item):
            filtered.append(item)
            continue
        yield item
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    """Split ``indices`` into those that fit ``max_positions`` and those that do not.

    Args:
        indices: iterable of dataset indices.
        size_fn (callable): maps an index to its size (number, dict or
            iterable, matching the form of ``max_positions``).
        max_positions: a number, a dict of per-key limits, or an iterable of
            limits compared component-wise; ``None`` components are skipped.
        raise_exception (bool, optional): accepted for signature
            compatibility; not used by this function.

    Returns:
        tuple: ``(kept, ignored)`` where ``kept`` is an int64 NumPy array and
        ``ignored`` is a list of the rejected indices.
    """

    def check_size(idx):
        # Evaluate the size once per index; size_fn may be expensive.
        idx_size = size_fn(idx)
        if isinstance(max_positions, (float, int)):
            return idx_size <= max_positions
        if isinstance(max_positions, dict):
            assert isinstance(idx_size, dict)
            # only keys present on both sides are compared
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(
                    a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key])
                )
                for key in intersect_keys
            )
        # For MultiCorpusSampledDataset, will generalize it later
        if not isinstance(idx_size, Iterable):
            # a scalar size must fit every limit
            return all(idx_size <= b for b in max_positions)
        return all(
            a is None or b is None or a <= b
            for a, b in zip(idx_size, max_positions)
        )

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).

    Returns:
        the indices whose size does not exceed ``max_positions``.
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )

    # A precomputed size vector can only be used with a scalar limit.
    size_vec = None
    if isinstance(max_positions, (float, int)):
        known_sizes = getattr(dataset, "sizes", None)
        if isinstance(known_sizes, np.ndarray):
            size_vec = known_sizes
        elif isinstance(known_sizes, list) and len(known_sizes) == 1:
            size_vec = known_sizes[0]

    if size_vec is None:
        indices, ignored = _filter_by_size_dynamic(
            indices, dataset.size, max_positions
        )
    else:
        ignored = indices[size_vec[indices] > max_positions].tolist()
        indices = indices[size_vec[indices] <= max_positions]

    n_ignored = len(ignored)
    if n_ignored > 0 and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if n_ignored > 0:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        src_sizes (np.array): source-side size of every sample
        tgt_sizes (np.array or None): target-side size of every sample;
            when None only the source limit is applied
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple).
            ``None`` disables filtering.

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        return indices, []
    # isinstance (not an exact type() match) so that NumPy scalars such as
    # np.int64 are treated as one shared limit instead of failing to unpack.
    if isinstance(max_sizes, (int, float, np.integer, np.floating)):
        max_src_size, max_tgt_size = max_sizes, max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes
    if tgt_sizes is None:
        ignored = indices[src_sizes[indices] > max_src_size]
    else:
        ignored = indices[
            (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
        ]
    # only re-index when something was actually removed
    if len(ignored) > 0:
        if tgt_sizes is None:
            indices = indices[src_sizes[indices] <= max_src_size]
        else:
            indices = indices[
                (src_sizes[indices] <= max_src_size)
                & (tgt_sizes[indices] <= max_tgt_size)
            ]
    return indices, ignored.tolist()
def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        )
    except ValueError:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        )

    # -1 is passed to the compiled helpers when a limit is not given
    # added int() to avoid TypeError: an integer is required
    max_tokens = -1 if max_tokens is None else int(max_tokens)
    max_sentences = -1 if max_sentences is None else max_sentences
    bsz_mult = required_batch_size_multiple

    # the compiled helpers take int64 arrays
    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)
    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is not None:
        shapes = np.array(fixed_shapes, dtype=np.int64)
        sort_order = np.lexsort(
            [
                shapes[:, 1].argsort(),  # length
                shapes[:, 0].argsort(),  # bsz
            ]
        )
        return batch_fixed_shapes_fast(indices, num_tokens_fn, shapes[sort_order])

    if num_tokens_vec is None:
        return batch_by_size_fn(
            indices,
            num_tokens_fn,
            max_tokens,
            max_sentences,
            bsz_mult,
        )
    return batch_by_size_vec(
        indices,
        num_tokens_vec,
        max_tokens,
        max_sentences,
        bsz_mult,
    )
def post_process(sentence: str, symbol: str):
    """Undo the sub-word segmentation named by ``symbol`` in ``sentence``.

    ``None`` and ``"none"`` return the sentence unchanged; an unrecognised
    ``symbol`` raises NotImplementedError.
    """
    # Schemes where spaces are dropped and a marker stands for the real space.
    space_markers = {
        "sentencepiece": "\u2581",
        "wordpiece": "_",
        "letter": "|",
        "_EOW": "_EOW",
    }
    if symbol is None or symbol == "none":
        return sentence
    if symbol in space_markers:
        marker = space_markers[symbol]
        return sentence.replace(" ", "").replace(marker, " ").strip()
    if symbol == "silence":
        import re

        without_sil = sentence.replace("<SIL>", "")
        return re.sub(" +", " ", without_sil).strip()
    if symbol in {"subword_nmt", "@@ ", "@@"}:
        continuation = "@@ " if symbol == "subword_nmt" else symbol
        return (sentence + " ").replace(continuation, "").rstrip()
    raise NotImplementedError(f"Unknown post_process option: {symbol}")
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: length of each span for the "static" type; also parameterizes the other types
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        mask_other: secondary parameter whose meaning depends on mask_type (see above)
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Returns:
        boolean array of shape ``(bsz, timesteps)``; True marks a masked position.

    Raises:
        Exception: if ``mask_type`` is not one of the supported values.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )
    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # only the non-padded part of this row is eligible for masking
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # make sure at least one span has a non-zero length
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span of ``length`` inside [s, e) and return the
                # leftover sub-ranges that are still large enough to be used.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # The builtin ``int`` replaces the ``np.int`` alias, which
                # was removed in NumPy 1.24 (same resulting dtype).
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    # no remaining part can hold this span
                    break
                # pick a part with probability proportional to its size
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            # sample span starts, then expand every start into its span
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        # drop positions past the valid length and de-duplicate overlaps
        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
def get_mem_usage():
    """Return used/available system memory in Mb as a string, or "N/A" without psutil."""
    try:
        import psutil
    except ImportError:
        return "N/A"

    mb = 1024 * 1024
    return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_padding_mask(lens):
    """Return a ``(bsz, max(lens))`` mask that is True at positions >= each length."""
    bsz = lens.size(0)
    max_lens = torch.max(lens).item()
    positions = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    # broadcasting (1, max_lens) against (bsz, 1) yields the (bsz, max_lens) mask
    return positions >= lens.view(bsz, 1)
# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_mask(lens):
    """Return the inverse of ``lengths_to_padding_mask(lens)``."""
    padding_mask = lengths_to_padding_mask(lens)
    return ~padding_mask
def get_buckets(sizes, num_buckets):
    """Return the unique bucket upper bounds for ``sizes``.

    The bounds are the percentiles of ``sizes`` at ``num_buckets + 1`` evenly
    spaced points between 0 and 100, with the 0th percentile dropped and each
    percentile snapped down to an existing value ("lower").

    Args:
        sizes: array-like of example sizes.
        num_buckets (int): number of percentile intervals.

    Returns:
        np.ndarray: sorted, de-duplicated bucket boundaries.
    """
    quantiles = np.linspace(0, 100, num_buckets + 1)
    try:
        # ``method`` replaces the deprecated ``interpolation`` keyword
        # (NumPy >= 1.22).
        boundaries = np.percentile(sizes, quantiles, method="lower")
    except TypeError:
        # older NumPy only knows the ``interpolation`` keyword
        boundaries = np.percentile(sizes, quantiles, interpolation="lower")
    return np.unique(boundaries[1:])
def get_bucketed_sizes(orig_sizes, buckets):
    """Return a copy of ``orig_sizes`` with each size raised to its bucket's upper bound.

    Buckets are visited in the given order; a size belongs to a bucket when
    it is greater than the previous bound (initially -1) and not greater
    than the current one.
    """
    bucketed = np.copy(orig_sizes)
    assert np.min(bucketed) >= 0
    lower_bound = -1
    for upper_bound in buckets:
        in_bucket = (bucketed > lower_bound) & (bucketed <= upper_bound)
        bucketed[in_bucket] = upper_bound
        lower_bound = upper_bound
    return bucketed
def _find_extra_valid_paths(dataset_path: str) -> set:
    """Return extension-less base names of entries matching 'valid*[0-9].*'.

    Every directory obtained from ``utils.split_paths(dataset_path)`` is listed.
    """
    sub_dirs = utils.split_paths(dataset_path)
    base_names = set()
    for sub_dir in sub_dirs:
        for entry in PathManager.ls(sub_dir):
            if re.match("valid*[0-9].*", entry) is None:
                continue
            base_names.add(os.path.basename(entry))
    # Remove .bin, .idx etc
    return {os.path.splitext(name)[0] for name in base_names}
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    dataset_cfg = train_cfg.dataset
    check_not_needed = (
        dataset_cfg.ignore_unused_valid_subsets
        or dataset_cfg.combine_valid_subsets
        or dataset_cfg.disable_validation
        or not hasattr(train_cfg.task, "data")
    )
    if check_not_needed:
        return

    found_subsets = _find_extra_valid_paths(train_cfg.task.data)
    requested_subsets = dataset_cfg.valid_subset.split(",")
    ignored_paths = [p for p in found_subsets if p not in requested_subsets]
    if not ignored_paths:
        return
    advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
    msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
    raise ValueError(msg)

View File

@@ -0,0 +1,178 @@
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t
ctypedef int64_t DTYPE_t
# Split `indices` into consecutive batches using the precomputed per-index
# token counts in `num_tokens_vec`.  Returns the list produced by np.split()
# at the computed batch boundaries; an empty `indices` yields [].
# A limit is only enforced when it is > 0 (see the chained comparisons below).
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    if indices.shape[0] == 0:
        return []

    # a single sample longer than max_tokens could never be placed in a batch
    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    # batches_ends[k] holds the exclusive end position of the k-th batch
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
            np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end),
        #      and tail [batch_end:pos].
        # 1) Every time when (batch + tail) forms a valid batch
        #      (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
        # 2) When (batch+tail) violates max_tokens or max_sentences constraints
        #      we finalize running batch, and tail becomes a new batch.
        # 3) There is a corner case when tail also violates constraints.
        #      In that situation [batch_end:pos-1] (tail without the current pos)
        #      gets added to the finalized batches, while [pos:pos] becomes a new tail.
        #
        # Important: For the sake of performance try to avoid using function calls within this loop.

        # running maximum of the tail, written without a max() call
        tail_max_tokens = tail_max_tokens \
                            if tail_max_tokens > num_tokens_view[pos] \
                            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
                                if batch_max_tokens > tail_max_tokens \
                                else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        # batch cost is counted as sentences * longest sentence
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                    (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            # commit the tail: the running batch now ends after pos
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    # account for a trailing batch whose recorded end is not the end of input
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])
# Variant of batch_by_size_vec for callers that only have a per-index
# callback: evaluates num_tokens_fn for every index into an int64 vector
# and then delegates to batch_by_size_vec with the same limits.
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
                                                               dtype=np.int64)
    cdef DTYPE_t[:] indices_view = indices
    # NOTE(review): num_tokens_vec_view is declared but never read below
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos
    for pos in range(indices_len):
        # one Python call per index to obtain its token count
        num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
    return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
        max_sentences, bsz_mult,)
cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape or -1 if none is found."""
    # a shape row is (max sentences, max tokens); it is valid when both
    # requested counts fit
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1
# Greedily group `indices` into batches such that every batch fits one of the
# (sentences, tokens) rows of `fixed_shapes_sorted`.  Returns a list of lists
# of indices.
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    cdef int64_t sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    # NOTE(review): mod_len is declared but never used in this function
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        # longest sample seen so far in the batch being built
        sample_len = max(sample_len, num_tokens)

        # would the batch still fit some shape with this sample added?
        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            # no shape fits: close the current batch and start over with the
            # full shape list
            # NOTE(review): sample_len is reset to 0 although the current
            # sample is appended to the new batch below, and `batch` is
            # appended even when empty -- confirm this is intended
            batches.append(batch)
            batch = []
            sample_lens = []
            sample_len = 0
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    # flush the last, partially filled batch
    if len(batch) > 0:
        batches.append(batch)

    return batches

View File

@@ -0,0 +1,443 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from . import FairseqDataset, data_utils
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Collate a list of samples into a padded mini-batch sorted by source length.

    Args:
        samples (list[dict]): samples with ``id`` and ``source`` entries and
            an optional ``target`` entry.
        pad_idx (int): value used to pad shorter sequences.
        eos_idx: accepted for interface compatibility; not used here, the
            last token of each sample is used as EOS when shifting targets.
        vocab: accepted for interface compatibility; not used here.
        left_pad_source (bool, optional): pad sources on the left.
        left_pad_target (bool, optional): pad targets on the left.
        input_feeding (bool, optional): must be True; adds
            ``prev_output_tokens`` (targets with their last token moved to
            the front) to ``net_input``.
        pad_to_length (dict, optional): minimum padded length per stream,
            keyed by ``"source"`` and ``"target"``.

    Returns:
        dict: the mini-batch, or ``{}`` when ``samples`` is empty.
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    source_pad_to = pad_to_length["source"] if pad_to_length is not None else None
    # named sample_ids so the builtin ``id`` is not shadowed
    sample_ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=source_pad_to,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    sample_ids = sample_ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target_pad_to = (
            pad_to_length["target"] if pad_to_length is not None else None
        )
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=target_pad_to,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=target_pad_to,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": sample_ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # number of sentences in the batch; the previous value was the token
        # length of the first sample
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
        mask: fraction of word starts to mask.
        mask_random: probability that a masked position receives a random
            token instead of ``mask_idx``.
        insert: insertion-noise ratio (relative to the sequence length).
        rotate: probability of applying rolling noise to an item.
        permute_sentences: proportion of sentences to permute.
        bpe: BPE name; ``"gpt2"`` selects ``vocab.index("13")`` as the
            full-stop token, anything else uses ``vocab.eos()``.
        replace_length: one of -1, 0, 1.
        mask_length: one of "subword", "word", "span-poisson".
        poisson_lambda: lambda of the span-length distribution used when
            ``mask_length == "span-poisson"``.
        eos: end-of-sentence index; defaults to ``vocab.eos()``.
        item_transform_func: optional callable applied to ``(source, target)``
            after noising; must return the new ``(source, target)``.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        mask,
        mask_random,
        insert,
        rotate,
        permute_sentences,
        bpe,
        replace_length,
        mask_length,
        poisson_lambda,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = mask
        self.random_ratio = mask_random
        self.insert_ratio = insert
        self.rotate_ratio = rotate
        self.permute_sentence_ratio = permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        if bpe == "gpt2":
            self.full_stop_index = self.vocab.index("13")
        else:
            self.full_stop_index = self.vocab.eos()

        self.replace_length = replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            _lambda = poisson_lambda

            # Tabulate Poisson probabilities P(k) = e^-lambda * lambda^k / k!
            # until they become negligible (or 128 entries are reached).
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is part of the per-item noise seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` with a noised source."""
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # there can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        # NOTE(review): this bound admits len(vocab) itself; a strict "<" looks
        # intended, but confirm no caller relies on an out-of-vocab index.
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        """Return a copy of ``source`` with a proportion ``p`` of its sentences shuffled."""
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        """Return a per-token word-start indicator; first and last tokens are never starts."""
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        """Mask (or delete) a fraction ``p`` of the words in ``source``.

        Span lengths come from ``mask_span_distribution`` when it is set and
        are 1 otherwise. Zero-length spans are turned into insertion noise.
        Note that ``source`` is modified in place when tokens are replaced.
        """
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        """Permute a proportion ``p`` of the interior tokens in place and return ``tokens``."""
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """Rotate the interior tokens by a random offset, keeping the first and last token fixed."""
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        """Insert ``ceil(len(tokens) * p)`` noise tokens at random interior positions.

        A ``random_ratio`` share of the inserted tokens are random vocabulary
        items; the rest are ``mask_idx``.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        # This class only stores ``self.dataset``; it has no ``src``/``tgt``.
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset advertises prefetch support."""
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )

View File

@@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections import Counter
from multiprocessing import Pool
import torch
from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        # Special symbols are registered first, so they get the lowest indices.
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        """Return the symbol at ``idx``, or the unk word if ``idx`` is past the end."""
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        """Return the stored count for the symbol at ``idx``."""
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words. A 2-D tensor
        is converted row by row and the rows are joined with newlines.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Forward every option, including unk_string and separator, so
            # each row is rendered exactly like a 1-D call would render it.
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    unk_string=unk_string,
                    include_eos=include_eos,
                    separator=separator,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            # New word, or an explicit overwrite: append a fresh entry and
            # point the word at it.
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Sorting the pairs before building the Counter fixes the insertion order.
        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.

        ``f`` is either a path (opened as UTF-8) or an open text file object.
        A missing path raises ``FileNotFoundError``, a malformed row raises
        ``ValueError`` and a duplicate word without the overwrite flag raises
        ``RuntimeError``.
        """
        if isinstance(f, str):
            try:
                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for raw_line in lines[indices_start_line:]:
            try:
                line, field = raw_line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                # Report the whole offending row, not the partially split one.
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{raw_line.rstrip()}\""
                )

    def _save(self, f, kv_iterator):
        """Write "key value" rows to ``f``; a string ``f`` is opened as a path first."""
        if isinstance(f, str):
            PathManager.mkdirs(os.path.dirname(f))
            with PathManager.open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        return [], []

    def _load_meta(self, lines):
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        """Return ``length`` random non-special indices with EOS in the last position."""
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        """Tokenize ``line`` and return its symbol indices as an IntTensor.

        Unknown words are added to the dictionary when ``add_if_not_exist``
        is true and mapped to unk otherwise. ``consumer(word, idx)`` is called
        for every word when given.
        """
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        """Count the tokens (plus one ``eos_word`` per line) of one file chunk."""
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        """Add every token of ``filename`` to ``dict``, optionally using a worker pool."""

        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
class TruncatedDictionary(object):
    """View of a dictionary that exposes only its first ``length`` symbols.

    The instance takes on a dynamically created class deriving from both this
    class and the wrapped dictionary's class, and shares the wrapped
    dictionary's ``__dict__``, so all other attributes and methods behave like
    the wrapped dictionary's.
    """

    def __init__(self, wrapped_dict, length):
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {},
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return the symbol at ``i``, or the unk symbol when ``i`` is truncated away."""
        if i < self.length:
            return self.wrapped_dict[i]
        # Look the unk *symbol* up through the wrapped dictionary; returning
        # the bare unk index would hand callers an int instead of a symbol.
        return self.wrapped_dict[self.wrapped_dict.unk()]

View File

@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
# Registry for tokenizers, selected with the "--tokenizer" option; the fourth
# value returned by setup_registry is not needed here.
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
    "--tokenizer",
    default=None,
)
# Registry for BPE encoders, selected with the "--bpe" option.
build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
    "--bpe",
    default=None,
)
# automatically import any Python files in the encoders/ directory
# (files starting with "_" are skipped; sorted() makes the import order
# deterministic). Importing a module runs its @register_* decorators.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip the ".py" suffix to obtain the module name.
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.data.encoders." + module)

View File

@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass
@dataclass
class ByteBpeConfig(FairseqDataclass):
    """Configuration for the "byte_bpe" encoder."""

    # Path handed to file_utils.cached_path and then loaded by sentencepiece.
    # NOTE(review): "???" is presumably the "value must be provided" marker of
    # the config system — confirm.
    sentencepiece_model_path: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
    """SentencePiece BPE applied on top of the byte-level character encoding."""

    def __init__(self, cfg):
        model_path = file_utils.cached_path(cfg.sentencepiece_model_path)
        try:
            import sentencepiece as spm

            processor = spm.SentencePieceProcessor()
            self.sp = processor
            processor.Load(model_path)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        """Byte-encode ``x`` and return its SentencePiece pieces joined by spaces."""
        pieces = self.sp.EncodeAsPieces(byte_encode(x))
        return SPACE.join(pieces)

    @staticmethod
    def decode(x: str) -> str:
        """Drop piece separators, restore escaped spaces and decode the bytes."""
        without_separators = x.replace(SPACE, "")
        restored = without_separators.replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(restored)

View File

@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
# Matches any run of whitespace characters.
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
# U+2581, the character used in place of a literal space in encoded text.
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
# (the ranges below also leave out 127-159 and 173)
PRINTABLE_LATIN = set(
    list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
# One character per byte value: printable Latin bytes map to the character
# with the same code point, every other byte to code point 256 + byte.
BYTE_TO_BCHAR = {
    b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
# Inverse lookup: character -> byte value.
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
def byte_encode(x: str) -> str:
    """Collapse whitespace runs to single spaces and map each UTF-8 byte to its character."""
    collapsed = WHITESPACE_NORMALIZER.sub(SPACE, x)
    utf8_bytes = collapsed.encode("utf-8")
    return "".join(BYTE_TO_BCHAR[value] for value in utf8_bytes)
def byte_decode(x: str) -> str:
    """Map byte characters back to bytes and decode them as UTF-8.

    Returns the empty string when ``x`` cannot be decoded: either the bytes
    are not valid UTF-8, or ``x`` contains a character that is not part of the
    byte alphabet. Returning "" in both cases lets callers such as
    ``smart_byte_decode`` fall back to partial recovery instead of crashing.
    """
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except (KeyError, ValueError):
        # KeyError: character outside the byte alphabet.
        # ValueError: covers UnicodeDecodeError for invalid UTF-8.
        return ""
def smart_byte_decode(x: str) -> str:
    """Decode ``x``; if that fails, recover as many valid characters as possible.

    The fallback is a dynamic program over prefixes of ``x``: each position is
    either skipped or ends a chunk of 1-4 byte characters that decodes on its
    own, and the split with the most decodable chunks is kept.
    """
    decoded = byte_decode(x)
    if decoded != "":
        return decoded

    length = len(x)
    best = [0] * (length + 1)  # best[k]: most chunks recoverable from x[:k]
    prev = [0] * (length + 1)  # prev[k]: where the step ending at k started
    for end in range(1, length + 1):
        # Default: skip the character at end - 1.
        best[end] = best[end - 1]
        prev[end] = end - 1
        for width in range(1, min(4, end) + 1):
            start = end - width
            if best[start] + 1 > best[end] and len(byte_decode(x[start:end])) > 0:
                best[end] = best[start] + 1
                prev[end] = start

    # Walk the back-pointers from the end, collecting the decoded chunks.
    pieces = []
    cursor = length
    while cursor > 0:
        start = prev[cursor]
        if best[cursor] == best[start] + 1:
            pieces.append(byte_decode(x[start:cursor]))
        cursor = start
    pieces.reverse()
    return "".join(pieces)

View File

@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
SPACE,
SPACE_ESCAPE,
byte_encode,
smart_byte_decode,
)
@register_bpe("bytes")
class Bytes(object):
    """Encoder that emits one space-separated symbol per UTF-8 byte."""

    def __init__(self, *unused):
        """Accept and ignore any arguments."""

    @staticmethod
    def add_args(parser):
        """This encoder has no command-line arguments."""

    @staticmethod
    def encode(x: str) -> str:
        """Byte-encode ``x``, escape its spaces and separate the symbols with spaces."""
        symbols = byte_encode(x).replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(symbols)

    @staticmethod
    def decode(x: str) -> str:
        """Drop the separators, restore escaped spaces and decode the bytes."""
        joined = x.replace(SPACE, "")
        restored = joined.replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(restored)

View File

@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_bpe
# Separator placed between encoded symbols.
SPACE = chr(32)
# U+2581, substituted for literal spaces so they survive the separator.
SPACE_ESCAPE = chr(9601)
@register_bpe("characters")
class Characters(object):
    """Encoder that emits one space-separated symbol per character."""

    def __init__(self, *unused):
        """Accept and ignore any arguments."""

    @staticmethod
    def add_args(parser):
        """This encoder has no command-line arguments."""

    @staticmethod
    def encode(x: str) -> str:
        """Escape the spaces in ``x`` and separate its characters with spaces."""
        symbols = x.replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(symbols)

    @staticmethod
    def decode(x: str) -> str:
        """Drop the separators and turn escaped spaces back into spaces."""
        joined = x.replace(SPACE, "")
        return joined.replace(SPACE_ESCAPE, SPACE)

View File

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class fastBPEConfig(FairseqDataclass):
    """Configuration for the "fastbpe" encoder."""

    # Path to the BPE codes file; resolved through file_utils.cached_path.
    # NOTE(review): "???" is presumably the "value must be provided" marker of
    # the config system — confirm.
    bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
    """BPE encoder backed by the ``fastBPE`` package."""

    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=fastbpe")
        codes_path = file_utils.cached_path(cfg.bpe_codes)
        try:
            import fastBPE

            self.bpe = fastBPE.fastBPE(codes_path)
            self.bpe_symbol = "@@ "
        except ImportError:
            raise ImportError("Please install fastBPE with: pip install fastBPE")

    def encode(self, x: str) -> str:
        """Apply BPE to the single sentence ``x``."""
        encoded = self.bpe.apply([x])
        return encoded[0]

    def decode(self, x: str) -> str:
        """Remove the BPE continuation markers from ``x``."""
        # The trailing space lets a marker at the very end of x match as well.
        padded = x + " "
        return padded.replace(self.bpe_symbol, "").rstrip()

View File

@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder
# Default locations of the GPT-2 BPE resources, used when the config does not
# override them.
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@dataclass
class GPT2BPEConfig(FairseqDataclass):
    """Configuration for the "gpt2" BPE encoder."""

    # Location of encoder.json; resolved through file_utils.cached_path.
    gpt2_encoder_json: str = field(
        default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
    )
    # Location of vocab.bpe; resolved through file_utils.cached_path.
    gpt2_vocab_bpe: str = field(
        default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
    )
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
    """GPT-2 byte-level BPE; encoded text is a space-separated list of token ids."""

    def __init__(self, cfg):
        encoder_json_path = file_utils.cached_path(cfg.gpt2_encoder_json)
        vocab_bpe_path = file_utils.cached_path(cfg.gpt2_vocab_bpe)
        self.bpe = get_encoder(encoder_json_path, vocab_bpe_path)

    def encode(self, x: str) -> str:
        """Return the BPE token ids of ``x`` joined by spaces."""
        token_ids = self.bpe.encode(x)
        return " ".join(str(token_id) for token_id in token_ids)

    def decode(self, x: str) -> str:
        """Turn a space-separated id string back into text.

        The literal tokens "<unk>" and "<mask>" are passed through unchanged;
        every other token is converted to an int id.
        """
        passthrough = {"<unk>", "<mask>"}
        tokens = [tok if tok in passthrough else int(tok) for tok in x.split()]
        return self.bpe.decode(tokens)

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True when the decoded form of ``x`` starts with a space."""
        decoded = self.decode(x)
        return decoded.startswith(" ")

View File

@@ -0,0 +1,140 @@
"""
Byte pair encoding utilities from GPT-2.
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""
import json
from functools import lru_cache
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every byte value (0-255) to a unicode character.

    The reversible BPE codes operate on unicode strings, so each byte needs a
    stand-in character. Bytes that are already printable keep their own
    character; the remaining bytes (whitespace, control characters and the
    like, which the BPE code cannot handle) are assigned the code points
    256, 257, ... in increasing byte order. Working on bytes this way avoids
    spending a large share of the BPE vocabulary on rare unicode characters
    just to avoid UNKs.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    already_mapped = set(printable)
    remaining = [b for b in range(2**8) if b not in already_mapped]

    # Printable bytes first, then the remapped ones, matching insertion order.
    byte_values = printable + remaining
    code_points = printable + [2**8 + offset for offset in range(len(remaining))]
    return dict(zip(byte_values, map(chr, code_points)))
def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length
    strings). A word with fewer than two symbols, including an empty one,
    has no adjacent pairs and yields an empty set.
    """
    return set(zip(word, word[1:]))
class Encoder:
    """GPT-2 byte-level BPE encoder/decoder.

    Args:
        encoder (dict): maps BPE token strings to integer ids.
        bpe_merges (list): merge pairs, listed from highest to lowest priority.
        errors (str): error handling passed to ``bytes.decode`` when decoding.
    """

    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Lower rank means the merge is applied earlier.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re

            self.re = re
        except ImportError:
            raise ImportError("Please install regex with: pip install regex")

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        """Apply the BPE merges to ``token`` and return its parts joined by spaces.

        Results are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Return the list of BPE token ids for ``text``."""
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            # Re-express the token's UTF-8 bytes as their stand-in characters.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Return the text for a list of token ids.

        Entries that are not known ids are used as literal text.
        """
        text = "".join([self.decoder.get(token, token) for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text
def get_encoder(encoder_json_path, vocab_bpe_path):
    """Build an ``Encoder`` from an ``encoder.json`` and a ``vocab.bpe`` file.

    Both files are read as UTF-8 so the result does not depend on the
    platform's default text encoding.
    """
    with open(encoder_json_path, "r", encoding="utf-8") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
        bpe_data = f.read()
    # The first and last lines of the merges file are not merge rules
    # (presumably a header line and the empty string after the final newline).
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

View File

@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class BertBPEConfig(FairseqDataclass):
    """Configuration for the "bert" BPE encoder."""

    # When true, lower-casing is disabled and the cased default vocabulary is used.
    bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
    # Optional vocabulary file; when unset a pretrained "bert-base-*" vocabulary is used.
    bpe_vocab_file: Optional[str] = field(
        default=None, metadata={"help": "bpe vocab file"}
    )
@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object):
    """BPE wrapper that delegates to a ``transformers.BertTokenizer``."""

    def __init__(self, cfg):
        """Create the tokenizer from ``cfg.bpe_vocab_file`` or a pretrained name."""
        try:
            from transformers import BertTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers with: pip install transformers"
            )

        if not cfg.bpe_vocab_file:
            # No explicit vocab: fall back to a stock pretrained tokenizer.
            if cfg.bpe_cased:
                pretrained_name = "bert-base-cased"
            else:
                pretrained_name = "bert-base-uncased"
            self.bert_tokenizer = BertTokenizer.from_pretrained(pretrained_name)
        else:
            self.bert_tokenizer = BertTokenizer(
                cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
            )

    def encode(self, x: str) -> str:
        """Tokenize *x* and return the pieces joined by single spaces."""
        pieces = self.bert_tokenizer.tokenize(x)
        return " ".join(pieces)

    def decode(self, x: str) -> str:
        """Rebuild text from a space-separated string of pieces."""
        joined = self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
        return self.bert_tokenizer.clean_up_tokenization(joined)

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True unless *x* carries the ``##`` continuation prefix."""
        is_continuation = x.startswith("##")
        return not is_continuation

View File

@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from fairseq import file_utils
@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
    """Options for the ``hf_byte_bpe`` BPE: merges/vocab paths and prefix space."""

    bpe_merges: str = field(
        default="???",
        metadata={"help": "path to merges.txt"},
    )
    bpe_vocab: str = field(
        default="???",
        metadata={"help": "path to vocab.json"},
    )
    bpe_add_prefix_space: bool = field(
        default=False,
        metadata={"help": "add prefix space before encoding"},
    )
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
    """BPE wrapper around ``tokenizers.ByteLevelBPETokenizer``."""

    # Tokens that ``decode`` passes through as strings instead of ints.
    def __init__(self, cfg):
        """Resolve the vocab/merges paths and build the tokenizer."""
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                "Please install huggingface/tokenizers with: pip install tokenizers"
            )

        vocab_path = file_utils.cached_path(cfg.bpe_vocab)
        merges_path = file_utils.cached_path(cfg.bpe_merges)

        self.bpe = ByteLevelBPETokenizer(
            vocab_path,
            merges_path,
            add_prefix_space=cfg.bpe_add_prefix_space,
        )

    def encode(self, x: str) -> str:
        """Encode *x* and return its ids as a space-separated string."""
        ids = self.bpe.encode(x).ids
        return " ".join(str(i) for i in ids)

    def decode(self, x: str) -> str:
        """Decode a whitespace-separated id string back to text.

        ``<unk>`` and ``<mask>`` are forwarded unchanged; every other
        item is converted with ``int``.
        """
        passthrough = {"<unk>", "<mask>"}
        items = []
        for tok in x.split():
            if tok in passthrough:
                items.append(tok)
            else:
                items.append(int(tok))
        return self.bpe.decode(items)

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True when the decoded form of *x* starts with a space."""
        decoded = self.decode(x)
        return decoded.startswith(" ")

View File

@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    """Options for the ``moses`` tokenizer."""

    source_lang: str = field(
        default="en",
        metadata={"help": "source language"},
    )
    target_lang: str = field(
        default="en",
        metadata={"help": "target language"},
    )
    moses_no_dash_splits: bool = field(
        default=False,
        metadata={"help": "don't apply dash split rules"},
    )
    moses_no_escape: bool = field(
        default=False,
        metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
    )
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    """Tokenizer/detokenizer pair backed by ``sacremoses``."""

    def __init__(self, cfg: MosesTokenizerConfig):
        """Keep *cfg* and build a tokenizer (source lang) and detokenizer (target lang)."""
        self.cfg = cfg

        try:
            # These local names shadow this class only inside __init__.
            from sacremoses import MosesTokenizer, MosesDetokenizer

            self.tok = MosesTokenizer(cfg.source_lang)
            self.detok = MosesDetokenizer(cfg.target_lang)
        except ImportError:
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            )

    def encode(self, x: str) -> str:
        """Tokenize *x* and return the result as a single string."""
        split_dashes = not self.cfg.moses_no_dash_splits
        do_escape = not self.cfg.moses_no_escape
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=split_dashes,
            return_str=True,
            escape=do_escape,
        )

    def decode(self, x: str) -> str:
        """Detokenize the whitespace-separated tokens in *x*."""
        tokens = x.split()
        return self.detok.detokenize(tokens)

View File

@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("nltk", dataclass=FairseqDataclass)
class NLTKTokenizer(object):
    """Tokenizer that splits text with ``nltk.tokenize.word_tokenize``."""

    def __init__(self, *unused):
        """Import NLTK lazily and keep a reference to ``word_tokenize``."""
        try:
            from nltk.tokenize import word_tokenize

            self.word_tokenize = word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")

    def encode(self, x: str) -> str:
        """Return the tokens of *x* joined by single spaces."""
        words = self.word_tokenize(x)
        return " ".join(words)

    def decode(self, x: str) -> str:
        """Return *x* unchanged; no detokenization is performed."""
        return x

View File

@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class SentencepieceConfig(FairseqDataclass):
    """Options for the ``sentencepiece`` BPE: model path and sampling settings."""

    sentencepiece_model: str = field(
        default="???",
        metadata={"help": "path to sentencepiece model"},
    )
    sentencepiece_enable_sampling: bool = field(
        default=False,
        metadata={"help": "enable sampling"},
    )
    sentencepiece_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "soothing parameter for unigram sampling, "
                "and merge probability for BPE-dropout"
            )
        },
    )
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object):
    """BPE wrapper around a loaded ``SentencePieceProcessor``."""

    def __init__(self, cfg):
        """Store the sampling options and load the model named in *cfg*."""
        self.enable_sampling = cfg.sentencepiece_enable_sampling
        self.alpha = cfg.sentencepiece_alpha
        model_path = file_utils.cached_path(cfg.sentencepiece_model)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(model_path)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        """Encode *x* into string pieces joined by single spaces."""
        pieces = self.sp.Encode(
            x,
            out_type=str,
            enable_sampling=self.enable_sampling,
            alpha=self.alpha,
        )
        return " ".join(pieces)

    def decode(self, x: str) -> str:
        """Drop the separating spaces, turn U+2581 markers into spaces, and strip."""
        glued = x.replace(" ", "")
        spaced = glued.replace("\u2581", " ")
        return spaced.strip()

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True for special symbols and for pieces starting with U+2581."""
        # Special elements are always considered beginnings.
        # HACK: this logic is already present in fairseq/tasks/masked_lm.py
        # but these special tokens are also contained in the sentencepiece
        # vocabulary which causes duplicate special tokens. This hack makes
        # sure that they are all taken into account.
        if x in ["<unk>", "<s>", "</s>", "<pad>"]:
            return True
        starts_word = x.startswith("\u2581")
        return starts_word

View File

@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("space", dataclass=FairseqDataclass)
class SpaceTokenizer(object):
    """Tokenizer that collapses every whitespace run into one space."""

    def __init__(self, *unused):
        """Compile the whitespace-run pattern once."""
        self.space_tok = re.compile(r"\s+")

    def encode(self, x: str) -> str:
        """Replace each run of whitespace in *x* with a single space."""
        normalized = self.space_tok.sub(" ", x)
        return normalized

    def decode(self, x: str) -> str:
        """Return *x* unchanged."""
        return x

View File

@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
    """Options for the ``subword_nmt`` BPE: codes file and separator."""

    bpe_codes: str = field(
        default="???",
        metadata={"help": "path to subword NMT BPE"},
    )
    bpe_separator: str = field(
        default="@@",
        metadata={"help": "BPE separator"},
    )
@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
    """BPE wrapper around ``subword_nmt.apply_bpe.BPE``."""

    def __init__(self, cfg):
        """Validate *cfg*, resolve the codes file, and build the BPE object."""
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            # Go through subword_nmt's own argument parser so that the
            # remaining options take the library's defaults.
            cli_args = ["--codes", codes, "--separator", cfg.bpe_separator]
            parsed = apply_bpe.create_parser().parse_args(cli_args)
            self.bpe = apply_bpe.BPE(
                parsed.codes,
                parsed.merges,
                parsed.separator,
                None,
                parsed.glossaries,
            )
            self.bpe_symbol = parsed.separator + " "
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt"
            )

    def encode(self, x: str) -> str:
        """Apply BPE to the line *x*."""
        segmented = self.bpe.process_line(x)
        return segmented

    def decode(self, x: str) -> str:
        """Remove every ``<separator> `` occurrence and trailing whitespace."""
        # The appended space lets a separator at the very end match too.
        padded = x + " "
        return padded.replace(self.bpe_symbol, "").rstrip()

View File

@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.data import encoders
def get_whole_word_mask(args, dictionary):
    """Build a per-symbol "starts a word" mask for *dictionary*.

    Returns:
        A ``torch.ByteTensor`` with one entry per dictionary index, or
        ``None`` when ``encoders.build_bpe(args)`` yields no BPE.
    """
    bpe = encoders.build_bpe(args)
    if bpe is None:
        return None

    def is_beginning_of_word(i):
        # special elements are always considered beginnings
        if i < dictionary.nspecial:
            return True
        symbol = dictionary[i]
        if symbol.startswith("madeupword"):
            return True
        try:
            return bpe.is_beginning_of_word(symbol)
        except ValueError:
            # Symbols the BPE cannot interpret count as beginnings.
            return True

    flags = [is_beginning_of_word(idx) for idx in range(len(dictionary))]
    return torch.ByteTensor(flags)

View File

@@ -0,0 +1,205 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch.utils.data
from fairseq.data import data_utils
logger = logging.getLogger(__name__)
class EpochListening:
    """Mixin for objects that want to be told when the epoch changes."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Whether one :class:`fairseq.data.EpochBatchIterator` can serve
        this dataset across epochs.

        Must be ``False`` when sample sizes can change between epochs,
        because batches may then need to be regenerated every epoch.
        Datasets that rely on ``set_epoch`` should consider returning
        ``False``. This default implementation returns ``True``.
        """
        reusable = True
        return reusable

    def set_epoch(self, epoch):
        """Receive the updated epoch number at the start of an epoch.

        The default implementation does nothing.
        """
        return None
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """Dataset base class that adds helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples into a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the token count of one sample; used to enforce
        ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the token counts for the positions in *indices*; used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple; used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return the indices in the order batches should be built from.

        The default is ``0 .. len(self) - 1`` as an int64 array.
        """
        count = len(self)
        return np.arange(count, dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching (default: no)."""
        return False

    def attr(self, attr: str, index: int):
        """Return the attribute named *attr*, or ``None`` if absent.

        *index* is accepted but not used by this default implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first item of each tuple is the batch size; it may be ``None``
        to infer the max batch size from ``--max-tokens``. The second item
        is the max supported length as given by
        :func:`fairseq.data.FairseqDataset.num_tokens`.

        :func:`fairseq.data.FairseqDataset.batch_by_size` uses this to
        restrict batch shapes, which is useful on TPUs to avoid too many
        dynamic shapes (and recompilations). The default is ``None``.
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """Split an ordered set of indices into batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                # Fill in a missing batch size from the token budget.
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    return min(bsz, max_sentences)
                if (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    # Round down to the nearest required multiple.
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            adjusted = []
            for shape_bsz, shape_tokens in fixed_shapes:
                adjusted.append([adjust_bsz(shape_bsz, shape_tokens), shape_tokens])
            fixed_shapes = np.array(adjusted)

        # The vectorized token counts are optional; fall back to the
        # per-index callback when a subclass does not provide them.
        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices, removing those that are longer
        than specified in *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        scalar_limit = isinstance(max_sizes, (float, int))
        sizes = getattr(self, "sizes", None)

        if scalar_limit and isinstance(sizes, np.ndarray):
            # Fast path: one size array covering every sample.
            ignored = indices[sizes[indices] > max_sizes].tolist()
            indices = indices[sizes[indices] <= max_sizes]
        elif scalar_limit and isinstance(sizes, list) and len(sizes) == 1:
            # Fast path: a one-element list wrapping the size array.
            only = sizes[0]
            ignored = indices[only[indices] > max_sizes].tolist()
            indices = indices[only[indices] <= max_sizes]
        else:
            # General path: delegate to the per-sample ``size`` callback.
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """Base class for datasets that have to be read sequentially, usually
    because the data is streamed or otherwise can't be manipulated on a
    single machine.
    """

    def __iter__(self):
        """Subclasses must provide the iteration logic."""
        raise NotImplementedError

View File

@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess
import threading
from pathlib import Path
import numpy as np
import torch
def fasta_file_path(prefix_path):
    """Return *prefix_path* with the ``.fasta`` extension appended."""
    extension = ".fasta"
    return prefix_path + extension
class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format
    """

    def __init__(self, path: str, cache_indices=False):
        """Index ``<path>.fasta``.

        Args:
            path: dataset prefix; the data file is ``<path>.fasta``.
            cache_indices: when true, load the index from
                ``<path>.fasta.idx.npy`` if it exists, otherwise build it
                and save it there.
        """
        self.fn = fasta_file_path(path)
        # One open file handle per thread; see _get_file.
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        """Return this thread's file handle, opening it on first use."""
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        """Return ``(description_line, sequence)`` for record *idx*.

        The sequence is every following line, stripped and concatenated,
        up to the next line starting with ``>`` or the end of the file.
        """
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        parts = []
        while line != "" and line[0] != ">":
            parts.append(line.strip())
            line = f.readline()
        return desc, "".join(parts)

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        """Compute record byte offsets and sequence lengths with shell tools.

        Returns:
            tuple: ``(offsets, sizes)`` as int64 numpy arrays.
        """
        import shlex

        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        # The path is spliced into a shell command line, so it must be
        # quoted: otherwise spaces or shell metacharacters in the file name
        # break the pipeline or get interpreted by the shell.
        quoted = shlex.quote(path)
        bytes_offsets = subprocess.check_output(
            f"cat {quoted} | tqdm --bytes --total $(wc -c < {quoted})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {quoted} | tqdm --bytes --total $(wc -c < {quoted})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
        sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
        return bytes_np, sizes_np

    def __setstate__(self, state):
        """Restore pickled state and recreate the per-thread storage."""
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        """Return the instance dict without the ``threadlocal`` entry."""
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "threadlocal"
        }

    def __del__(self):
        # ``threadlocal`` may be missing if __init__ failed before setting
        # it; use a default so the finalizer never raises on its own.
        local = getattr(self, "threadlocal", None)
        if local is not None and hasattr(local, "f"):
            local.f.close()
            del local.f

    @staticmethod
    def exists(path):
        """Return whether ``<path>.fasta`` exists on disk."""
        return os.path.exists(fasta_file_path(path))
class EncodedFastaDataset(FastaDataset):
    """
    FastaDataset variant that returns dictionary-encoded sequences
    instead of the raw ``(description, sequence)`` pairs.
    """

    def __init__(self, path, dictionary):
        """Open the dataset with index caching enabled and keep *dictionary*."""
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        """Encode the sequence of record *idx* one character per token."""
        _desc, sequence = super().__getitem__(idx)
        encoded = self.dictionary.encode_line(sequence, line_tokenizer=list)
        return encoded.long()

View File

@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
HuffmanMMapIndex,
HuffmanMMapIndexedDataset,
HuffmanMMapIndexedDatasetBuilder,
vocab_file_path,
)
__all__ = [
"HuffmanCoder",
"HuffmanCodeBuilder",
"HuffmanMMapIndexedDatasetBuilder",
"HuffmanMMapIndexedDataset",
"HuffmanMMapIndex",
"vocab_file_path",
]

View File

@@ -0,0 +1,267 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
import typing as tp
from collections import Counter, deque
from dataclasses import dataclass
from bitarray import bitarray, util
from fairseq.data import Dictionary
# basically we have to write to addressable bytes for the memory mapped
# dataset loader. Sentences that get encoded to a length that is not a
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
BLOCKSIZE = 8
class HuffmanCoder:
    """Encode token lists to padded bytes and back, using a Huffman tree."""

    def __init__(
        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
    ):
        """Keep the tree *root*, build its symbol table, and store the
        names of the special symbols."""
        self.root = root
        self.table = root.code_table()
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos

    def _pad(self, a: bitarray) -> bitarray:
        """
        bitpadding, 1 then 0.

        If the array is already a multiple of blocksize, we add a full block.
        """
        pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
        padding = bitarray("1" + "0" * pad_len)
        return a + padding

    def _unpad(self, a: bitarray) -> bitarray:
        """
        remove the bitpadding.

        There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
        """
        # Everything from the last 1 bit onwards is padding, so cut there
        # (the 1 itself is removed too).
        remove_cnt = util.rindex(a, 1)
        return a[:remove_cnt]

    def encode(self, iter: tp.List[str]) -> bytes:
        """
        encode a list of tokens and return bytes. We use bitpadding to make
        sure the encoded bits fit in bytes.

        Raises:
            Exception: if a token has no code and no unknown-word symbol
                is configured.
        """
        a = bitarray()
        for token in iter:
            code = self.get_code(token)
            if code is None:
                if self.unk_word is None:
                    raise Exception(f"unknown token {token} cannot be encoded.")
                # Fall back to the unknown-word symbol's code.
                code = self.get_code(self.unk_word)
            a = a + code
        return self._pad(a).tobytes()

    def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
        """
        take bitpadded bytes and decode it to a set of leaves. You can then
        use each node to find the symbol/id
        """
        a = bitarray()
        a.frombytes(bits)
        return self.root.decode(self._unpad(a))

    def get_code(self, symbol: str) -> tp.Optional[bitarray]:
        """Return the code for *symbol*, or ``None`` if it is not in the table."""
        node = self.get_node(symbol)
        return None if node is None else node.code

    def get_node(self, symbol: str) -> tp.Optional["HuffmanNode"]:
        """Return the leaf for *symbol*, or ``None`` if it is not in the table."""
        return self.table.get(symbol)

    @classmethod
    def from_file(
        cls,
        filename: str,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> "HuffmanCoder":
        """Build a coder from a symbol/count file via ``HuffmanCodeBuilder``."""
        builder = HuffmanCodeBuilder.from_file(filename)
        return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)

    def to_file(self, filename, sep="\t"):
        """Write ``symbol<sep>count`` lines, ordered by node id."""
        nodes = list(self.table.values())
        nodes.sort(key=lambda n: n.id)
        with open(filename, "w", encoding="utf-8") as output:
            for n in nodes:
                output.write(f"{n.symbol}{sep}{n.count}\n")

    def __iter__(self):
        """Iterate over the leaf nodes in the symbol table."""
        for n in self.table.values():
            yield n

    def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
        """Return a new coder built from the summed counts of both coders."""
        builder = HuffmanCodeBuilder()
        for n in self:
            builder.increment(n.symbol, n.count)
        for n in other_coder:
            builder.increment(n.symbol, n.count)
        return builder.build_code()

    def __eq__(self, other: "HuffmanCoder") -> bool:
        return self.table == other.table

    def __len__(self) -> int:
        return len(self.table)

    def __contains__(self, sym: str) -> bool:
        return sym in self.table

    def to_dictionary(self) -> Dictionary:
        """Create a finalized ``Dictionary`` holding every symbol and count."""
        # __init__ stores the special symbols as ``*_word``; the plain
        # ``self.bos`` / ``self.unk`` / ... attributes do not exist.
        dictionary = Dictionary(
            bos=self.bos_word,
            unk=self.unk_word,
            pad=self.pad_word,
            eos=self.eos_word,
        )
        for n in self:
            dictionary.add_symbol(n.symbol, n=n.count)
        dictionary.finalize()
        return dictionary
@dataclass
class HuffmanNode:
    """
    One node of a Huffman tree; leaves carry a symbol and, once
    ``code_table`` has run, their code.
    """

    id: int
    count: int
    symbol: tp.Optional[str] = None
    left: tp.Optional["HuffmanNode"] = None
    right: tp.Optional["HuffmanNode"] = None
    code: tp.Optional[bitarray] = None

    def is_leaf(self) -> bool:
        """Return True when the node has neither a left nor a right child."""
        has_child = self.left is not None or self.right is not None
        return not has_child

    def code_table(
        self, prefix: tp.Optional[bitarray] = None
    ) -> tp.Dict[str, "HuffmanNode"]:
        """Assign codes to the leaves below this node.

        Right branches append a 0 bit and left branches a 1 bit to
        *prefix*. Returns a mapping from symbol to leaf node.
        """
        path = bitarray() if prefix is None else prefix
        if self.is_leaf():
            # leaf could be the root if there is only one symbol
            if len(path) > 0:
                self.code = path
            else:
                self.code = bitarray("0")
            return {self.symbol: self}

        right_table = self.right.code_table(path + bitarray([0]))
        left_table = self.left.code_table(path + bitarray([1]))
        merged = dict(left_table)
        merged.update(right_table)
        return merged

    def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
        """Walk the tree along *bits*, yielding each leaf that is reached.

        Raises:
            Exception: if a bit leads to a missing child, or if the bits
                end part-way through a code.
        """
        node = self
        for bit in bits:
            # 0 goes right, anything else goes left
            node = node.right if bit == 0 else node.left
            if node is None:
                # we shouldn't be on a leaf here
                raise Exception("fell off a leaf")
            if node.is_leaf():
                yield node
                node = self
        if node != self:
            raise Exception("couldn't decode all the bits")
class HuffmanCodeBuilder:
    """
    build a dictionary with occurrence count and then build the Huffman code for it.
    """

    def __init__(self):
        self.symbols = Counter()

    def add_symbols(self, *syms) -> None:
        """Count one occurrence of each symbol in *syms*."""
        self.symbols.update(syms)

    def increment(self, symbol: str, cnt: int) -> None:
        """Add *cnt* to the count of *symbol*."""
        self.symbols[symbol] += cnt

    @classmethod
    def from_file(cls, filename):
        """Load counts from a file of whitespace-separated symbol/count lines.

        Lines that contain only whitespace are skipped.
        """
        c = cls()
        with open(filename, "r", encoding="utf-8") as input:
            for line in input:
                if not line.strip():
                    # A blank line would otherwise fail in int("").
                    continue
                split = re.split(r"[\s]+", line)
                c.increment(split[0], int(split[1]))
        return c

    def to_file(self, filename, sep="\t"):
        """Write ``symbol<sep>count`` lines, most common symbol first."""
        with open(filename, "w", encoding="utf-8") as output:
            for (tok, cnt) in self.symbols.most_common():
                output.write(f"{tok}{sep}{cnt}\n")

    def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
        """Pop the lower-count node from the right end of *q1* or *q2*.

        An empty queue defers to the other one; on equal counts *q2* wins.
        """
        if len(q1) == 0:
            return q2.pop()

        if len(q2) == 0:
            return q1.pop()

        if q1[-1].count < q2[-1].count:
            return q1.pop()

        return q2.pop()

    def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
        """Return a new builder holding the summed counts of both builders."""
        new_c = self.symbols + c.symbols
        new_b = HuffmanCodeBuilder()
        new_b.symbols = new_c
        return new_b

    def build_code(
        self,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> HuffmanCoder:
        """Build a ``HuffmanCoder`` from the collected counts.

        Special symbols with a zero count are added with a count of one
        before the tree is built.
        """
        assert len(self.symbols) > 0, "cannot build code from empty list of symbols"

        for special in (bos, pad, eos, unk):
            if self.symbols[special] == 0:
                self.add_symbols(special)

        node_id = 0
        leaves_queue = deque(
            [
                HuffmanNode(symbol=symbol, count=count, id=idx)
                for idx, (symbol, count) in enumerate(self.symbols.most_common())
            ]
        )  # left are the most common, right are the least common

        if len(leaves_queue) == 1:
            root = leaves_queue.pop()
            root.id = 0
            # Forward the special symbols here as well, so this path agrees
            # with the general case below.
            return HuffmanCoder(root, bos=bos, pad=pad, eos=eos, unk=unk)

        nodes_queue = deque()

        while len(leaves_queue) > 0 or len(nodes_queue) != 1:
            # get the lowest two nodes at the head of each queue
            node1 = self._smallest(leaves_queue, nodes_queue)
            node2 = self._smallest(leaves_queue, nodes_queue)

            # add new node
            nodes_queue.appendleft(
                HuffmanNode(
                    count=node1.count + node2.count, left=node1, right=node2, id=node_id
                )
            )
            node_id += 1

        # we are left with the root
        return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)

View File

@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
import os
import shutil
import struct
import typing as tp
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import indexed_dataset
from fairseq.data.huffman import HuffmanCoder
from fairseq.file_io import PathManager
class HuffmanMMapIndex:
    """
    Index of the offsets in the huffman binary file.

    The file holds a header, then the list of sizes (num tokens) for each
    instance, and finally the addresses of each instance.
    """

    _HDR_MAGIC = b"HUFFIDX\x00\x00"
    _VERSION = 1

    @classmethod
    def writer(cls, path: str, data_len: int):
        """Return a context manager that writes an index file to *path*."""

        class _Writer:
            def __enter__(self):
                self._file = open(path, "wb")

                # header: magic, version, length of the data file
                header = (
                    cls._HDR_MAGIC,
                    struct.pack("<Q", cls._VERSION),
                    struct.pack("<Q", data_len),
                )
                for chunk in header:
                    self._file.write(chunk)

                return self

            def write(self, sizes, pointers):
                # the item count completes the header
                self._file.write(struct.pack("<Q", len(sizes)))

                # sizes as int32, then address pointers as int64
                size_bytes = np.array(sizes, dtype=np.int32).tobytes(order="C")
                self._file.write(size_bytes)

                pointer_bytes = np.array(pointers, dtype=np.int64).tobytes(order="C")
                self._file.write(pointer_bytes)

            def __exit__(self, exc_type, exc_val, exc_tb):
                self._file.close()

        return _Writer()

    def __init__(self, path):
        """Read the header of *path* and memory-map the size/pointer arrays."""
        with open(path, "rb") as stream:
            # read headers
            magic_test = stream.read(9)
            assert self._HDR_MAGIC == magic_test, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            (version,) = struct.unpack("<Q", stream.read(8))
            assert (
                self._VERSION == version
            ), f"Unexpected file version{version} != code version {self._VERSION}"

            # length of the data file, then number of items in the index
            (self._data_len,) = struct.unpack("<Q", stream.read(8))
            (self._len,) = struct.unpack("<Q", stream.read(8))
            offset = stream.tell()

        indexed_dataset._warmup_mmap_file(path)

        self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
        self._sizes = np.frombuffer(
            self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
        )
        # the pointers start right after the sizes
        pointers_offset = offset + self._sizes.nbytes
        self._pointers = np.frombuffer(
            self._bin_buffer,
            dtype=np.int64,
            count=self._len,
            offset=pointers_offset,
        )

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap

    def __iter__(self):
        """Yield ``(pointer, size)`` for every item in order."""
        position = 0
        while position < self._len:
            yield self[position]
            position += 1

    @property
    def data_len(self):
        """Length of the data file as recorded in the header."""
        return self._data_len

    @property
    def sizes(self):
        """Array of per-item sizes."""
        return self._sizes

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return ``(pointer, size)`` for item *i*."""
        return self._pointers[i], self._sizes[i]

    def __len__(self):
        return self._len
def vocab_file_path(prefix_path):
    """Return *prefix_path* with the ``.vocab`` extension appended."""
    extension = ".vocab"
    return prefix_path + extension
class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
    """
    Indexed dataset that uses mmap and memoryview to read data from disk
    that was compressed with a HuffmanCoder.
    """

    def __init__(self, prefix_path):
        super().__init__()

        # Declare every attribute up front, then load the real values.
        self._prefix_path = None
        self._index = None
        self._bin_buffer = None
        self._coder = None
        self._file = None
        self._bin_buffer_mmap = None

        self._do_init(prefix_path)

    def __getstate__(self):
        """Pickle only the path prefix."""
        return self._prefix_path

    def __setstate__(self, state):
        """Reopen everything from the pickled path prefix."""
        self._do_init(state)

    def _do_init(self, prefix_path):
        """Load the index and coder, then memory-map the data file."""
        self._prefix_path = prefix_path
        index_path = indexed_dataset.index_file_path(self._prefix_path)
        self._index = HuffmanMMapIndex(index_path)
        self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))

        indexed_dataset._warmup_mmap_file(
            indexed_dataset.data_file_path(self._prefix_path)
        )
        self._file = os.open(
            indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
        )
        self._bin_buffer_mmap = mmap.mmap(
            self._file,
            self._index.data_len,
            access=mmap.ACCESS_READ,
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        del self._bin_buffer
        if self._file:
            os.close(self._file)
        del self._index

    def __len__(self):
        return len(self._index)

    def _decode(self, i):
        """Return the decoded leaf nodes of item *i*.

        The stored pointer of an item is used as the end of its byte
        range; the range starts at the previous item's pointer, or at the
        start of the buffer for the first item.
        """
        end, _ = self._index[i]
        if i == 0:
            chunk = self._bin_buffer[:end]
        else:
            start, _ = self._index[i - 1]
            chunk = self._bin_buffer[start:end]

        return self._coder.decode(chunk.tobytes())

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        """Return the node ids of item *i* as an int64 tensor."""
        ids = [node.id for node in self._decode(i)]
        return torch.tensor(ids, dtype=torch.int64)

    def __iter__(self):
        """Yield every item in index order."""
        total = len(self)
        for position in range(total):
            yield self[position]

    def get_symbols(self, i):
        """Yield the symbols of item *i*."""
        for node in self._decode(i):
            yield node.symbol

    @property
    def sizes(self):
        """Per-item sizes taken from the index."""
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @property
    def coder(self):
        """The coder used to decode items."""
        return self._coder

    @staticmethod
    def exists(prefix_path):
        """Return whether the index, data and vocab files all exist."""
        index_path = indexed_dataset.index_file_path(prefix_path)
        if not PathManager.exists(index_path):
            return PathManager.exists(index_path)
        return (
            PathManager.exists(indexed_dataset.data_file_path(prefix_path))
            and PathManager.exists(vocab_file_path(prefix_path))
        )
class HuffmanMMapIndexedDatasetBuilder:
    """
    Helper to build a memory mapped dataset with a huffman encoder.

    It can be opened and closed manually or used as a context manager.
    The caller provides the coder, which is stored alongside the dataset.
    The builder first writes the vocab file, then opens the binary file so
    items can be streamed into it; the index is written when the builder
    is closed (the index has to fit in memory).
    """

    def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
        self._path_prefix = path_prefix
        self._coder = coder
        self._sizes = []
        self._ptrs = []
        self._data_len = 0

    def open(self):
        """Write the vocab file and open the data file for writing."""
        self._coder.to_file(vocab_file_path(self._path_prefix))
        data_path = indexed_dataset.data_file_path(self._path_prefix)
        self._data_file = open(data_path, "wb")

    def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
        self.open()
        return self

    def add_item(self, tokens: tp.List[str]) -> None:
        """
        Add a list of tokens to the dataset; they are compressed with the
        provided coder before being written to file.
        """
        encoded = self._coder.encode(tokens)
        code_len = len(encoded)
        # Each pointer is the running end offset of its item.
        previous_end = self._ptrs[-1] if len(self._ptrs) > 0 else 0
        self._sizes.append(len(tokens))
        self._ptrs.append(previous_end + code_len)
        self._data_len += code_len
        self._data_file.write(encoded)

    def append(self, other_dataset_path_prefix: str) -> None:
        """
        Append an existing dataset.

        Beware, if it wasn't built with the same coder, you are in trouble.
        """
        other_index = HuffmanMMapIndex(
            indexed_dataset.index_file_path(other_dataset_path_prefix)
        )
        # Shift the other dataset's pointers by the data written so far.
        for other_ptr, other_size in other_index:
            self._ptrs.append(other_ptr + self._data_len)
            self._sizes.append(other_size)

        # Concatenate data
        other_data_path = indexed_dataset.data_file_path(other_dataset_path_prefix)
        with open(other_data_path, "rb") as f:
            shutil.copyfileobj(f, self._data_file)

        self._data_len += other_index.data_len

    def close(self):
        """Close the data file and write the index file."""
        self._data_file.close()
        index_path = indexed_dataset.index_file_path(self._path_prefix)
        with HuffmanMMapIndex.writer(index_path, self._data_len) as index:
            index.write(self._sizes, self._ptrs)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

View File

@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class IdDataset(FairseqDataset):
    """Dataset whose item at position ``index`` is ``index`` itself."""

    def __getitem__(self, index):
        # The sample is simply its own index.
        return index

    def __len__(self):
        # This dataset always reports a length of zero.
        return 0

    def collater(self, samples):
        """Collate a list of samples (indices) into a single tensor."""
        batch = torch.tensor(samples)
        return batch

View File

@@ -0,0 +1,587 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import shutil
import struct
from functools import lru_cache
import numpy as np
import torch
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex
from . import FairseqDataset
from typing import Union
def best_fitting_int_dtype(
    max_int_to_represent,
) -> Union[np.uint16, np.uint32, np.int64]:
    """Return a NumPy integer type able to hold *max_int_to_represent*.

    Args:
        max_int_to_represent: largest value that must be storable, or ``None``
            when it is unknown.

    Returns:
        ``np.uint16`` for values below 65500, ``np.uint32`` for values below
        4294967295 (and when the bound is unknown), ``np.int64`` otherwise.
    """
    if max_int_to_represent is None:
        return np.uint32  # Safe guess
    if max_int_to_represent < 65500:
        return np.uint16
    if max_int_to_represent < 4294967295:
        return np.uint32
    # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly
    # https://github.com/numpy/numpy/issues/5745
    return np.int64
def get_available_dataset_impl():
    """Return the entries of ``DATASET_IMPL_CHOICES`` converted to strings."""
    return [str(choice) for choice in DATASET_IMPL_CHOICES]
def infer_dataset_impl(path):
    """Guess which dataset implementation produced the files at *path*.

    Returns:
        ``"raw"``, ``"cached"``, ``"mmap"``, ``"huffman"`` or ``"fasta"``, or
        ``None`` when nothing matches (including an index file whose first
        8 bytes match none of the known magic headers).
    """
    if IndexedRawTextDataset.exists(path):
        return "raw"

    if IndexedDataset.exists(path):
        # Only the first 8 bytes of the index are needed to tell formats apart.
        with open(index_file_path(path), "rb") as index_stream:
            header = index_stream.read(8)
        if header == IndexedDataset._HDR_MAGIC:
            return "cached"
        if header == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
            return "mmap"
        if header == HuffmanMMapIndex._HDR_MAGIC[:8]:
            return "huffman"
        return None

    if FastaDataset.exists(path):
        return "fasta"
    return None
def make_builder(out_file, impl, vocab_size=None):
    """Create the dataset builder matching *impl* that writes to *out_file*.

    Args:
        out_file: path of the data file the builder writes.
        impl: dataset implementation name.
        vocab_size: used to choose the element dtype of an ``"mmap"`` builder.

    Raises:
        NotImplementedError: for ``impl == "fasta"``.
        ValueError: for ``impl == "huffman"``.
    """
    if impl == "fasta":
        raise NotImplementedError
    if impl == "huffman":
        raise ValueError(
            "Use HuffmanCodeBuilder directly as it has a different interface."
        )
    if impl == "mmap":
        element_dtype = best_fitting_int_dtype(vocab_size)
        return MMapIndexedDatasetBuilder(out_file, dtype=element_dtype)
    # Every other implementation name gets the legacy indexed builder.
    return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
    """Open the dataset stored at *path* using implementation *impl*.

    Returns:
        The dataset instance, or ``None`` when *impl* is not recognised or the
        files expected for it do not exist.
    """
    if impl == "raw":
        if IndexedRawTextDataset.exists(path):
            assert dictionary is not None
            return IndexedRawTextDataset(path, dictionary)
    elif impl == "lazy":
        if IndexedDataset.exists(path):
            return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "cached":
        if IndexedDataset.exists(path):
            return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "mmap":
        if MMapIndexedDataset.exists(path):
            return MMapIndexedDataset(path)
    elif impl == "fasta":
        if FastaDataset.exists(path):
            from fairseq.data.fasta_dataset import EncodedFastaDataset

            return EncodedFastaDataset(path, dictionary)
    elif impl == "huffman":
        if HuffmanMMapIndexedDataset.exists(path):
            return HuffmanMMapIndexedDataset(path)
    return None
def dataset_exists(path, impl):
    """Tell whether the files for a dataset of kind *impl* exist at *path*.

    Implementation names other than ``"raw"``, ``"mmap"`` and ``"huffman"``
    are checked with :meth:`IndexedDataset.exists`.
    """
    if impl == "raw":
        return IndexedRawTextDataset.exists(path)
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
    if impl == "huffman":
        return HuffmanMMapIndexedDataset.exists(path)
    return IndexedDataset.exists(path)
def read_longs(f, n):
    """Read *n* int64 values from the binary stream *f* into a new array."""
    values = np.empty(n, dtype=np.int64)
    # Fill the array's buffer directly from the stream.
    f.readinto(values)
    return values
def write_longs(f, a):
    """Write the values of *a* to the binary stream *f* as int64."""
    as_int64 = np.array(a, dtype=np.int64)
    f.write(as_int64)
# Maps the numeric dtype code stored in an index file header to the NumPy
# scalar type of the stored elements. _dtype_header_code performs the reverse
# lookup. Codes 6 and 7 both resolve to 64-bit floats (np.double is NumPy's
# alias for np.float64).
_code_to_dtype = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,
    7: np.double,
    8: np.uint16,
    9: np.uint32,
    10: np.uint64,
}
def _dtype_header_code(dtype) -> int:
    """Return the header code under which *dtype* is listed in ``_code_to_dtype``.

    The first matching code in the table's order is returned.

    Raises:
        ValueError: if *dtype* is not in the table.
    """
    for code, candidate in _code_to_dtype.items():
        if candidate == dtype:
            return code
    raise ValueError(dtype)
def index_file_path(prefix_path):
    """Return the index file path for a dataset prefix (``<prefix>.idx``)."""
    suffix = ".idx"
    return prefix_path + suffix
def data_file_path(prefix_path):
    """Return the data file path for a dataset prefix (``<prefix>.bin``)."""
    suffix = ".bin"
    return prefix_path + suffix
class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset"""

    # First 8 bytes of a valid index file.
    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.path = path
        self.fix_lua_indexing = fix_lua_indexing
        # The data file is opened lazily on first item access.
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        """Parse the index file of the dataset with prefix *path*."""
        with open(index_file_path(path), "rb") as index_stream:
            header = index_stream.read(8)
            assert header == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            (version,) = struct.unpack("<Q", index_stream.read(8))
            assert version == 1
            dtype_code, element_size = struct.unpack("<QQ", index_stream.read(16))
            self.element_size = element_size
            self.dtype = _code_to_dtype[dtype_code]
            num_items, num_sizes = struct.unpack("<QQ", index_stream.read(16))
            self._len = num_items
            self.s = num_sizes
            # Both offset tables hold one more entry than there are items.
            self.dim_offsets = read_longs(index_stream, self._len + 1)
            self.data_offsets = read_longs(index_stream, self._len + 1)
            self.sizes = read_longs(index_stream, self.s)

    def read_data(self, path):
        """Open the (unbuffered) data file of the dataset with prefix *path*."""
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
        if not 0 <= i < self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    @lru_cache(maxsize=8)
    def __getitem__(self, i) -> torch.Tensor:
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        # The item's shape is the slice of ``sizes`` delimited by dim_offsets.
        shape = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        buffer = np.empty(shape, dtype=self.dtype)
        byte_offset = self.data_offsets[i] * self.element_size
        self.data_file.seek(byte_offset)
        self.data_file.readinto(buffer)
        item = torch.from_numpy(buffer).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        """Tell whether both the index and the data file exist for *path*."""
        has_index = PathManager.exists(index_file_path(path))
        return has_index and PathManager.exists(data_file_path(path))

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
    """Indexed dataset that serves items from an in-memory cache filled by ``prefetch``."""

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        # Flat array holding the elements of all prefetched items.
        self.cache = None
        # Item index -> start position of that item inside ``cache``.
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        """Load the items in *indices* into the cache, replacing its previous content.

        Does nothing when every requested index is already cached.
        """
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = sum(
            self.data_offsets[i + 1] - self.data_offsets[i] for i in indices
        )
        self.cache = np.empty(total_size, dtype=self.dtype)
        self.cache_index.clear()
        cursor = 0
        for i in indices:
            self.cache_index[i] = cursor
            num_elements = self.data_offsets[i + 1] - self.data_offsets[i]
            # Read straight into the slice of the cache reserved for item i.
            destination = self.cache[cursor : cursor + num_elements]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(destination)
            cursor += num_elements
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        shape = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        buffer = np.empty(shape, dtype=self.dtype)
        start = self.cache_index[i]
        np.copyto(buffer, self.cache[start : start + buffer.size])
        item = torch.from_numpy(buffer).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item
class IndexedRawTextDataset(FairseqDataset):
    """Reads a text file and binarizes every line in memory when instantiated.

    The original lines are kept in memory as well."""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        # NOTE: this instance attribute has the same name as the ``size``
        # method defined below, so on instances it hides that method.
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        """Encode every line of the UTF-8 text file *path* with *dictionary*."""
        with open(path, "r", encoding="utf-8") as text_file:
            for raw_line in text_file:
                self.lines.append(raw_line.strip("\n"))
                encoded = dictionary.encode_line(
                    raw_line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                )
                tokens = encoded.long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        """Raise ``IndexError`` unless ``0 <= i < self.size``."""
        if not 0 <= i < self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        """Return line *i* of the source file without its trailing newline."""
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(path)
class IndexedDatasetBuilder:
    """Writes a dataset in the TorchNet ``IndexedDataset`` on-disk format.

    Items are streamed into the data file with :meth:`add_item` (or copied
    from another dataset with :meth:`merge_file_`); the index file is written
    by :meth:`finalize`.
    """

    # Size in bytes of one element for every dtype that has a header code.
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float64: 8,  # np.double is an alias of np.float64
        np.uint16: 2,
        np.uint32: 4,
        np.uint64: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        """
        Args:
            out_file: path of the data file to create.
            dtype: NumPy type of the stored elements; must be a key of
                ``element_sizes``.
        """
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        # data_offsets[i] is the element offset at which item i starts.
        self.data_offsets = [0]
        # dim_offsets[i] is the position in ``sizes`` where item i's shape starts.
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        """Append *tensor* to the data file and record its offsets and shape."""
        # +1 for Lua compatibility
        num_bytes = self.out_file.write(
            np.array(tensor.numpy() + 1, dtype=self.dtype)
        )
        # Integer division keeps the offsets exact ints; true division would
        # produce floats, which lose precision for very large files.
        self.data_offsets.append(
            self.data_offsets[-1] + num_bytes // self.element_size
        )
        self.sizes.extend(tensor.size())
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        """Append the dataset with prefix *another_file* to this builder.

        The other dataset must store elements of the same dtype.
        """
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        # Shift the other dataset's offsets past what was written so far.
        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self.out_file)

    def finalize(self, index_file):
        """Close the data file and write the index to *index_file*."""
        self.out_file.close()
        # The context manager closes the index file even if a write fails.
        with open(index_file, "wb") as index:
            index.write(b"TNTIDX\x00\x00")
            index.write(struct.pack("<Q", 1))
            index.write(
                struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
            )
            index.write(
                struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes))
            )
            write_longs(index, self.dim_offsets)
            write_longs(index, self.data_offsets)
            write_longs(index, self.sizes)
def _warmup_mmap_file(path):
with open(path, "rb") as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
    """Dataset whose index and data files are accessed through ``np.memmap``.

    Items are returned as int64 tensors built from the slice of the data file
    described by the index (byte pointer and element count).
    """

    class Index:
        """Reader (and, via :meth:`writer`, writer) for the index file."""

        # First 9 bytes of a valid index file.
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            """Return a context manager that writes an index file to *path*.

            Args:
                path: index file to create (opened on ``__enter__``).
                dtype: NumPy type of the elements stored in the data file.
            """

            class _Writer:
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Header: magic, format version, dtype code.
                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", _dtype_header_code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    """Return the starting byte offset of every item."""
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        # ``size`` may be a NumPy int32 scalar (e.g. sizes
                        # read back from another index); convert it so the
                        # running offset is an unbounded Python int and
                        # cannot wrap around for data files beyond 2 GiB.
                        address += int(size) * dtype_size

                    return pointers

                def write(self, sizes):
                    """Write the item count, the sizes and the byte pointers."""
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = _code_to_dtype[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                # Position of the first byte after the header.
                offset = stream.tell()

            _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            # The int32 sizes come first, immediately followed by the int64
            # pointers; both arrays are views on the mapped file.
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            # (byte pointer into the data file, number of elements)
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        # Only the path is pickled; the maps are rebuilt by __setstate__.
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        """Load the index and map the data file of the dataset with prefix *path*."""
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        # Items are always handed out as int64.
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        """Tell whether both the index and the data file exist for *path*."""
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )
def get_indexed_dataset_to_local(path) -> str:
    """Resolve local copies of a dataset's index and data files.

    Returns:
        The common local path prefix, i.e. the local file paths without their
        ``.idx`` / ``.bin`` suffix.
    """
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    suffixes_ok = local_index_path.endswith(".idx") and local_data_path.endswith(
        ".bin"
    )
    assert suffixes_ok, (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    # Both suffixes are four characters long; the remaining prefixes must agree.
    local_path = local_data_path[:-4]
    index_prefix = local_index_path[:-4]
    assert local_path == index_prefix
    return local_path
class MMapIndexedDatasetBuilder:
    """Writes the data file and the index file read by ``MMapIndexedDataset``."""

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        # Number of elements of every item written so far.
        self._sizes = []

    def add_item(self, tensor):
        """Append *tensor*, converted to the builder's dtype, to the data file."""
        converted = np.array(tensor.numpy(), dtype=self._dtype)
        raw = converted.tobytes(order="C")
        self._data_file.write(raw)
        self._sizes.append(converted.size)

    def merge_file_(self, another_file):
        """Append the dataset with prefix *another_file*; dtypes must match."""
        # Concatenate index
        other_index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert other_index.dtype == self._dtype
        self._sizes.extend(other_index.sizes)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as other_data:
            shutil.copyfileobj(other_data, self._data_file)

    def finalize(self, index_file):
        """Close the data file and write the index to *index_file*."""
        self._data_file.close()
        index_writer = MMapIndexedDataset.Index.writer(index_file, self._dtype)
        with index_writer as index:
            index.write(self._sizes)

View File

@@ -0,0 +1,883 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import math
import operator
import os
import queue
import time
from threading import Thread
from typing import Iterator, List
import numpy as np
import torch
from fairseq.data import data_utils
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Object used by _background_consumer to signal the source is exhausted
# to the main thread.
_sentinel = object()
class CountingIterator(object):
    """Iterator wrapper that keeps track of how many elements were consumed.

    Args:
        iterable (iterable): iterable to wrap
        start (int): initial value of the iteration count; the wrapped
            iterator itself is not advanced. When falsy, the ``n`` attribute
            of *iterable* is used if present, else 0.
        total (int): override the length reported by ``__len__``.
            This can be used to truncate *iterator*.

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=None, total=None):
        self._itr = iter(iterable)
        self.n = start if start else getattr(iterable, "n", 0)
        if total is None:
            total = self.n + len(iterable)
        self.total = total

    def __len__(self):
        return self.total

    def __iter__(self):
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        try:
            element = next(self._itr)
        except StopIteration:
            # The wrapped iterator ran out before the announced total.
            raise IndexError(
                f"Iterator expected to have length {self.total}, "
                f"but exhausted at position {self.n}."
            )
        self.n += 1
        return element

    def has_next(self):
        """Whether fewer than ``total`` elements have been consumed so far."""
        return self.n < self.total

    def skip(self, n):
        """Fast-forward the iterator by skipping n elements."""
        remaining = n
        while remaining > 0:
            next(self)
            remaining -= 1
        return self

    def take(self, n):
        """Truncate the iterator to n elements at most."""
        self.total = min(self.total, n)
        # Propagate this change to the underlying iterator
        if hasattr(self._itr, "take"):
            still_allowed = max(n - self.n, 0)
            self._itr.take(still_allowed)
        return self
class EpochBatchIterating(object):
    """Interface of iterators that produce batches epoch after epoch.

    Every member except :attr:`first_batch` raises ``NotImplementedError``
    here and is meant to be provided by subclasses.
    """

    def __len__(self) -> int:
        raise NotImplementedError()

    @property
    def next_epoch_idx(self):
        raise NotImplementedError()

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus (bool, optional): ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
            set_dataset_epoch (bool, optional): update the wrapped Dataset with
                the new epoch number (default: True).
        """
        raise NotImplementedError()

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        raise NotImplementedError()

    @property
    def iterations_in_epoch(self) -> int:
        """The number of consumed batches in the current epoch."""
        raise NotImplementedError()

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        raise NotImplementedError()

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        raise NotImplementedError()

    @property
    def first_batch(self):
        # Placeholder value returned by the base class.
        return "DUMMY"
class StreamingEpochBatchIterator(EpochBatchIterating):
    """A streaming-style iterator over a :class:`torch.utils.data.IterableDataset`.

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        max_sentences: batch size
        collate_fn (callable): merges a list of samples to form a mini-batch
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        buffer_size (int, optional): the number of batches to keep ready in the
            queue (capped at 20). When buffer_size is zero, no extra buffering
            is added on top of torch.utils.data.DataLoader.
        timeout (int, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative (default: ``0``).
        persistent_workers (bool, optional): forwarded to the DataLoader
            (default: ``False``).
    """

    def __init__(
        self,
        dataset,
        max_sentences=1,
        collate_fn=None,
        epoch=1,
        num_workers=0,
        buffer_size=0,
        timeout=0,
        persistent_workers=False,
    ):
        assert isinstance(dataset, torch.utils.data.IterableDataset)
        self.dataset = dataset
        self.max_sentences = max_sentences
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.timeout = timeout
        self.persistent_workers = persistent_workers
        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        # This upper limit here is to prevent people from abusing this feature
        # in a shared computing environment.
        self.buffer_size = min(buffer_size, 20)

        self._current_epoch_iterator = None

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        epoch_finished = (
            self._current_epoch_iterator is not None and self.end_of_epoch()
        )
        return self.epoch + 1 if epoch_finished else self.epoch

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """Advance to the next epoch index and return a fresh iterator for it."""
        self.epoch = self.next_epoch_idx
        if set_dataset_epoch and hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)
        epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
        self._current_epoch_iterator = epoch_itr
        return self._current_epoch_iterator

    def end_of_epoch(self) -> bool:
        """Whether the current epoch iterator has no batches left."""
        return not self._current_epoch_iterator.has_next()

    @property
    def iterations_in_epoch(self) -> int:
        """Number of batches consumed from the current epoch iterator (0 if none)."""
        if self._current_epoch_iterator is None:
            return 0
        return self._current_epoch_iterator.n

    def state_dict(self):
        """Return the iterator state; only the epoch number is stored."""
        return {
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore the epoch number from *state_dict*."""
        self.epoch = state_dict["epoch"]

    def _get_iterator_for_epoch(self, epoch, shuffle, offset=0):
        if self.num_workers > 0:
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        # Create data loader
        loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.max_sentences,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            timeout=self.timeout,
            worker_init_fn=getattr(self.dataset, "worker_init_fn", None),
            pin_memory=True,
            persistent_workers=self.persistent_workers,
        )

        # Wrap with a BufferedIterator if needed
        if self.buffer_size > 0:
            loader = BufferedIterator(self.buffer_size, loader)

        # Wrap with CountingIterator
        return CountingIterator(loader, start=offset)
class FrozenBatchSampler:
    """Batch sampler holding the batch list computed for one epoch.

    Args:
        ordered_batches (callable): called as
            ``ordered_batches(epoch, fix_batches_to_gpus, shuffle)`` to obtain
            the batches of an epoch.
        epoch: epoch for which the initial batches are computed.
        fix_batches_to_gpus: forwarded to *ordered_batches*.
        shuffle: forwarded to *ordered_batches*.
        initial_offset (int): number of leading batches dropped from the
            initial batch list.
    """

    def __init__(
        self,
        ordered_batches,
        epoch,
        fix_batches_to_gpus,
        shuffle,
        initial_offset,
    ):
        self.ordered_batches = ordered_batches
        self.fix_batches_to_gpus = fix_batches_to_gpus
        self.shuffle = shuffle
        self.make_batches_for_epoch(epoch, initial_offset)

    def make_batches_for_epoch(self, epoch, offset=0):
        """Recompute the batches for *epoch*, dropping the first *offset* ones."""
        epoch_batches = self.ordered_batches(
            epoch, self.fix_batches_to_gpus, self.shuffle
        )
        self.batches = epoch_batches
        if offset > 0:
            self.batches = epoch_batches[offset:]

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)
class EpochBatchIterator(EpochBatchIterating):
    """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.

    Compared to :class:`torch.utils.data.DataLoader`, this iterator:

    - can be reused across multiple epochs with the :func:`next_epoch_itr`
      method (optionally shuffled between epochs)
    - can be serialized/deserialized with the :func:`state_dict` and
      :func:`load_state_dict` methods
    - supports sharding with the *num_shards* and *shard_id* arguments

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of
            indices, or a callable to create such an iterator (~torch.utils.data.Sampler).
            A callable batch_sampler will be called for each epoch to enable per epoch dynamic
            batch iterators defined by this callable batch_sampler.
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        buffer_size (int, optional): the number of batches to keep ready in the
            queue (capped at 20). When buffer_size is zero, no extra buffering
            is added on top of torch.utils.data.DataLoader.
        timeout (int, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative (default: ``0``).
        disable_shuffling (bool, optional): force disable shuffling
            (default: ``False``).
        skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch
            for the sake of training stability, as the last batch is usually smaller than
            local_batch_size * distributed_word_size (default: ``False``).
        grouped_shuffling (bool, optional): enable shuffling batches in groups
            of num_shards. Ensures that each GPU receives similar length sequences when
            batches are sorted by length.
        reuse_dataloader (bool, optional): keep the DataLoader created for the
            first epoch and reuse it for later epochs (default: ``False``).
        persistent_workers (bool, optional): forwarded to the DataLoader
            (default: ``False``).
    """

    def __init__(
        self,
        dataset,
        collate_fn,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        buffer_size=0,
        timeout=0,
        disable_shuffling=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        reuse_dataloader=False,
        persistent_workers=False,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.batch_sampler = batch_sampler
        # A callable sampler is evaluated lazily (see ``frozen_batches``).
        self._frozen_batches = (
            None if callable(batch_sampler) else tuple(batch_sampler)
        )
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.num_workers = num_workers
        # This upper limit here is to prevent people from abusing this feature
        # in a shared computing environment.
        self.buffer_size = min(buffer_size, 20)
        self.timeout = timeout
        self.disable_shuffling = disable_shuffling
        self.skip_remainder_batch = skip_remainder_batch
        self.grouped_shuffling = grouped_shuffling

        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = not disable_shuffling
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = getattr(dataset, "supports_prefetch", False)

        self.dataloader = None
        self.reuse_dataloader = reuse_dataloader
        self.persistent_workers = persistent_workers

    @property
    def frozen_batches(self):
        """Tuple of batches; computed from a callable sampler on first access."""
        if self._frozen_batches is None:
            sampled = self.batch_sampler(self.dataset, self.epoch)
            self._frozen_batches = tuple(sampled)
        return self._frozen_batches

    @property
    def first_batch(self):
        """Collated first frozen batch, or ``"DUMMY"`` when the dataset opts out."""
        if len(self.frozen_batches) == 0:
            raise Exception(
                "The dataset is empty. This could indicate "
                "that all elements in the dataset have been skipped. "
                "Try increasing the max number of allowed tokens or using "
                "a larger dataset."
            )

        if not getattr(self.dataset, "supports_fetch_outside_dataloader", True):
            return "DUMMY"
        samples = [self.dataset[i] for i in self.frozen_batches[0]]
        return self.collate_fn(samples)

    def __len__(self):
        batches_per_shard = len(self.frozen_batches) / float(self.num_shards)
        return int(math.ceil(batches_per_shard))

    @property
    def n(self):
        return self.iterations_in_epoch

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        epoch_finished = (
            self._next_epoch_itr is None
            and self._cur_epoch_itr is not None
            and self.end_of_epoch()
        )
        return self.epoch + 1 if epoch_finished else self.epoch

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus (bool, optional): ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
            set_dataset_epoch (bool, optional): update the wrapped Dataset with
                the new epoch number (default: True).
        """
        if self.disable_shuffling:
            shuffle = False
        previous_epoch = self.epoch
        self.epoch = self.next_epoch_idx
        if set_dataset_epoch and hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)

        pending_itr = self._next_epoch_itr
        if pending_itr is not None:
            # An iterator prepared by load_state_dict takes precedence.
            self._cur_epoch_itr = pending_itr
            self._next_epoch_itr = None
        else:
            if callable(self.batch_sampler) and previous_epoch != self.epoch:
                # reset _frozen_batches to refresh the next epoch
                self._frozen_batches = None
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch,
                shuffle,
                fix_batches_to_gpus=fix_batches_to_gpus,
            )
        self.shuffle = shuffle
        return self._cur_epoch_itr

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        return not self._cur_epoch_itr.has_next()

    @property
    def iterations_in_epoch(self):
        """The number of consumed batches in the current epoch."""
        for candidate in (self._cur_epoch_itr, self._next_epoch_itr):
            if candidate is not None:
                return candidate.n
        return 0

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        if self.end_of_epoch():
            # A finished epoch is stored as the start of the following one.
            saved_epoch, saved_position = self.epoch + 1, 0
        else:
            saved_epoch, saved_position = self.epoch, self.iterations_in_epoch
        return {
            "version": 2,
            "epoch": saved_epoch,
            "iterations_in_epoch": saved_position,
            "shuffle": self.shuffle,
        }

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.epoch = state_dict["epoch"]
        itr_pos = state_dict.get("iterations_in_epoch", 0)
        version = state_dict.get("version", 1)
        if not itr_pos > 0:
            self._next_epoch_itr = None
            return

        # fast-forward epoch iterator
        self._next_epoch_itr = self._get_iterator_for_epoch(
            self.epoch,
            shuffle=state_dict.get("shuffle", True),
            offset=itr_pos,
        )
        if self._next_epoch_itr is not None:
            return
        if version == 1:
            # legacy behavior: we finished the epoch, increment epoch counter
            self.epoch += 1
        else:
            raise RuntimeError(
                "Cannot resume training due to dataloader mismatch, please "
                "report this to the fairseq developers. You can relaunch "
                "training with `--reset-dataloader` and it should work."
            )

    def _get_iterator_for_epoch(
        self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
        """Build the counting iterator for *epoch*, skipping *offset* batches.

        Returns ``None`` when a positive *offset* leaves no batches.
        """
        if self.reuse_dataloader and self.dataloader is not None:
            # Reuse the existing loader; only refresh the sampler's batches.
            self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
            itr = self.dataloader
        else:
            self.epoch_batch_sampler = FrozenBatchSampler(
                self.ordered_batches,
                epoch,
                fix_batches_to_gpus,
                shuffle,
                initial_offset=offset,
            )

            if offset > 0 and len(self.epoch_batch_sampler) == 0:
                return None

            if self.num_workers > 0:
                os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

            # Create data loader
            itr = torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_sampler=self.epoch_batch_sampler,
                num_workers=self.num_workers,
                timeout=self.timeout,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
            )

            if self.reuse_dataloader:
                self.dataloader = itr

        # Wrap with a BufferedIterator if needed
        if self.buffer_size > 0:
            itr = BufferedIterator(self.buffer_size, itr)

        # Wrap with CountingIterator
        itr = CountingIterator(itr, start=offset)

        if self.skip_remainder_batch:
            # TODO: Below is a lazy implementation which discard the final batch regardless
            # of whether it is a full batch or not.
            total_num_itrs = len(self.epoch_batch_sampler) - 1
            itr.take(total_num_itrs)
            logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")

        return itr

    def ordered_batches(self, epoch, fix_batches_to_gpus, shuffle):
        """Return this shard's batches for *epoch*, shuffled when requested."""

        def shuffle_batches(batches, seed):
            with data_utils.numpy_seed(seed):
                if self.grouped_shuffling:
                    # Shuffle whole groups of ``num_shards`` consecutive batches.
                    shard_count = self.num_shards
                    grouped_batches = [
                        batches[(g * shard_count) : ((g + 1) * shard_count)]
                        for g in range((len(batches) // shard_count))
                    ]
                    np.random.shuffle(grouped_batches)
                    batches = list(itertools.chain(*grouped_batches))
                else:
                    np.random.shuffle(batches)

            return batches

        def take_shard(batches):
            return list(
                ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
            )

        if self._supports_prefetch:
            batches = self.frozen_batches

            if shuffle and not fix_batches_to_gpus:
                batches = shuffle_batches(list(batches), self.seed + epoch)

            batches = take_shard(batches)
            self.dataset.prefetch([i for s in batches for i in s])

            if shuffle and fix_batches_to_gpus:
                # Shuffle after sharding, with a seed that depends on the shard.
                batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
            return batches

        if shuffle:
            batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
        else:
            batches = self.frozen_batches
        return take_shard(batches)
class GroupedIterator(CountingIterator):
    """Iterator wrapper that yields the wrapped items in lists (chunks).

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk
        skip_remainder_batch (bool, optional): if set, discard the last grouped batch in
            each training epoch, as the last grouped batch is usually smaller than
            local_batch_size * distributed_word_size * chunk_size (default: ``False``).

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, chunk_size, skip_remainder_batch=False):
        chunks_exact = len(iterable) / float(chunk_size)
        if skip_remainder_batch:
            # Only complete chunks count.
            total_num_itrs = int(math.floor(chunks_exact))
            logger.info(
                f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}"
            )
        else:
            total_num_itrs = int(math.ceil(chunks_exact))
            logger.info(f"grouped total_num_itrs = {total_num_itrs}")

        chunked = _chunk_iterator(iterable, chunk_size, skip_remainder_batch)
        consumed_chunks = int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size)))
        super().__init__(
            chunked,
            start=consumed_chunks,
            total=total_num_itrs,
        )
        self.chunk_size = chunk_size

        if skip_remainder_batch:
            self.take(total_num_itrs)
            # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that
            # training can move into the next epoch once the grouped iterator is exhausted.
            # Double-check this implementation in case unexpected behavior occurs.
            iterable.take(total_num_itrs * chunk_size)
def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False):
chunk = []
for x in itr:
chunk.append(x)
if len(chunk) == chunk_size:
yield chunk
chunk = []
if not skip_remainder_batch and len(chunk) > 0:
yield chunk
class ShardedIterator(CountingIterator):
    """Iterate over one shard of an iterable, padded to a common length.

    Args:
        iterable (iterable): iterable to wrap
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterate over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards* (default: None).

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(
        self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None
    ):
        """
        Args:
            skip_remainder_batch: ignored"""
        if not 0 <= shard_id < num_shards:
            raise ValueError("shard_id must be between 0 and num_shards")

        shard_count = float(num_shards)
        sharded_len = int(math.ceil(len(iterable) / shard_count))
        # Every ``num_shards``-th element starting at ``shard_id``.
        strided = itertools.islice(iterable, shard_id, len(iterable), num_shards)
        # Pairing with range(sharded_len) pads short shards with ``fill_value``.
        padded = itertools.zip_longest(
            range(sharded_len), strided, fillvalue=fill_value
        )
        already_consumed = int(math.ceil(getattr(iterable, "n", 0) / shard_count))
        super().__init__(
            map(operator.itemgetter(1), padded),
            start=already_consumed,
            total=sharded_len,
        )
class BackgroundConsumer(Thread):
    """Thread that copies items from ``source`` into ``queue``.

    After at most ``max_len`` items (or when the source is exhausted) the
    module-level ``_sentinel`` is enqueued. If anything raises, the exception
    object itself is enqueued instead.
    """

    def __init__(self, queue, source, max_len, cuda_device):
        super().__init__()
        self.cuda_device = cuda_device
        self.count = 0
        self._max_len = max_len
        self._source = source
        self._queue = queue

    def run(self):
        # set_device to avoid creation of GPU0 context when using pin_memory
        if self.cuda_device is not None:
            torch.cuda.set_device(self.cuda_device)

        try:
            for element in self._source:
                self._queue.put(element)
                self.count += 1
                # Keep going until the maximum length (if any) is reached.
                if self._max_len is None or self.count < self._max_len:
                    continue
                break
            # Signal the consumer we are done.
            self._queue.put(_sentinel)
        except Exception as err:
            self._queue.put(err)
class BufferedIterator(object):
    """Iterator that reads ahead from ``iterable`` through a bounded queue.

    A ``BackgroundConsumer`` thread is started lazily on the first
    ``__next__`` call and fills a queue of capacity ``size``.
    """

    def __init__(self, size, iterable):
        self._queue = queue.Queue(size)
        self._iterable = iterable
        self._consumer = None

        self.total = len(iterable)
        self.warning_time = None
        self.start_time = time.time()

    def _create_consumer(self):
        # Hand the current CUDA device (if any) to the worker thread.
        device = torch.cuda.current_device() if torch.cuda.is_available() else None
        self._consumer = BackgroundConsumer(
            self._queue, self._iterable, self.total, device
        )
        self._consumer.daemon = True
        self._consumer.start()

    def __iter__(self):
        return self

    def __len__(self):
        return self.total

    def take(self, n):
        """Cap the length at ``n`` and forward the cap to the wrapped iterable."""
        self.total = min(self.total, n)
        # Propagate this change to the underlying iterator
        if hasattr(self._iterable, "take"):
            self._iterable.take(n)
        return self

    def __next__(self):
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()

        # Notify the user if there is a data loading bottleneck
        low_water_mark = min(2, max(1, self._queue.maxsize // 2))
        if (
            self._queue.qsize() < low_water_mark
            and time.time() - self.start_time > 5 * 60
            and (
                self.warning_time is None
                or time.time() - self.warning_time > 15 * 60
            )
        ):
            logger.debug(
                "Data loading buffer is empty or nearly empty. This may "
                "indicate a data loading bottleneck, and increasing the "
                "number of workers (--num-workers) may help."
            )
            self.warning_time = time.time()

        # Get next example (blocks until the worker produces one)
        fetched = self._queue.get(True)
        if isinstance(fetched, Exception):
            # The worker thread failed; re-raise its exception here.
            raise fetched
        if fetched is _sentinel:
            raise StopIteration()
        return fetched
class GroupedEpochBatchIterator(EpochBatchIterator):
    """Grouped version of EpochBatchIterator

    It takes several samplers from different datasets.
    Each epoch, the dataset-wise samplers are shuffled individually with a
    different random seed. Those sub-samplers are then combined into one big
    sampler with a deterministic permutation to mix batches from different
    datasets. It will act like EpochBatchIterator but make sure
    1) data comes from one dataset each time
    2) different workers use the same order to fetch the data, so they will
       use data from the same dataset every time
    mult_rate is used for the update_freq > 1 case where we want to make sure
    update_freq mini-batches come from the same source
    """

    def __init__(
        self,
        dataset,
        collate_fn,
        batch_samplers,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
        mult_rate=1,
        buffer_size=0,
        skip_remainder_batch=False,
        reuse_dataloader=False,
        persistent_workers=False,
    ):
        super().__init__(
            dataset,
            collate_fn,
            batch_samplers,
            seed,
            num_shards,
            shard_id,
            num_workers,
            epoch,
            buffer_size,
            skip_remainder_batch=skip_remainder_batch,
            reuse_dataloader=reuse_dataloader,
            persistent_workers=persistent_workers,
        )
        # level 0: sub-samplers 1: batch_idx 2: batches
        self._frozen_batches = tuple(tuple(sub_batch) for sub_batch in batch_samplers)
        self.step_size = mult_rate * num_shards

        # Each sub-sampler is truncated to a whole number of steps.
        self.lengths = [
            (len(sub_batches) // self.step_size) * self.step_size
            for sub_batches in self.frozen_batches
        ]

    def __len__(self):
        return sum(self.lengths)

    @property
    def first_batch(self):
        if len(self.frozen_batches) == 0:
            raise Exception(
                "The dataset is empty. This could indicate "
                "that all elements in the dataset have been skipped. "
                "Try increasing the max number of allowed tokens or using "
                "a larger dataset."
            )

        if not self.dataset.supports_fetch_outside_dataloader:
            return "DUMMY"
        # Collate the first batch of the first sub-sampler.
        leading_indices = self.frozen_batches[0][0]
        return self.collate_fn([self.dataset[i] for i in leading_indices])

    def _get_iterator_for_epoch(
        self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
        def shuffle_batches(batches, seed):
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

        def return_full_batches(batch_sets, seed, shuffle):
            if shuffle:
                batch_sets = [
                    shuffle_batches(list(subset), seed) for subset in batch_sets
                ]

            # Drop the tail of each sub-sampler beyond its truncated length.
            batch_sets = [
                subset[: self.lengths[k]] for k, subset in enumerate(batch_sets)
            ]
            batches = list(itertools.chain.from_iterable(batch_sets))

            if not shuffle:
                return batches

            # Permute blocks of ``step_size`` consecutive batches so that each
            # block keeps batches from a single source together.
            with data_utils.numpy_seed(seed):
                idx = np.random.permutation(len(batches) // self.step_size)
                if len(idx) * self.step_size != len(batches):
                    raise ValueError(
                        "ERROR: %d %d %d %d"
                        % (len(idx), self.step_size, len(batches), self.shard_id),
                        ":".join(["%d" % x for x in self.lengths]),
                    )
                mini_shards = [
                    batches[block * self.step_size : (block + 1) * self.step_size]
                    for block in idx
                ]
                batches = list(itertools.chain.from_iterable(mini_shards))
            return batches

        if self._supports_prefetch:
            raise NotImplementedError("To be implemented")

        mixed = return_full_batches(self.frozen_batches, self.seed + epoch, shuffle)
        batches = list(
            ShardedIterator(mixed, self.num_shards, self.shard_id, fill_value=[])
        )

        if offset > 0 and offset >= len(batches):
            return None

        if self.num_workers > 0:
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=batches[offset:],
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
        )
        if self.buffer_size > 0:
            loader = BufferedIterator(self.buffer_size, loader)

        return CountingIterator(loader, start=offset)

View File

@@ -0,0 +1,477 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils
logger = logging.getLogger(__name__)
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Merge a list of sample dicts into one mini-batch dict.

    Samples are sorted by descending source length (counted as non-pad
    tokens); every per-sample tensor in the batch follows that order.

    Args:
        samples (List[dict]): samples with keys ``id`` and ``source`` and,
            optionally, ``target``, ``prev_output_tokens``, ``alignment``
            and ``constraints``
        pad_idx (int): padding index
        eos_idx (int): end-of-sentence index
        left_pad_source (bool, optional): pad source on the left
        left_pad_target (bool, optional): pad target on the left
        input_feeding (bool, optional): if no ``prev_output_tokens`` is given,
            build it from the target with EOS moved to the beginning
        pad_to_length (dict, optional): ``{"source": int, "target": int}``
            lengths forwarded to ``data_utils.collate_tokens``
        pad_to_multiple (int, optional): forwarded to
            ``data_utils.collate_tokens``

    Returns:
        dict: ``{}`` for an empty *samples*, otherwise a batch with keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` and ``target``,
        plus ``alignments``/``align_weights`` and ``constraints`` when the
        samples provide them.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        # The inverse indices map each row to its count directly.
        align_weights = align_tgt_c[align_tgt_i]
        return 1.0 / align_weights.float()

    src_pad_to = pad_to_length["source"] if pad_to_length is not None else None
    tgt_pad_to = pad_to_length["target"] if pad_to_length is not None else None

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge("source", left_pad=left_pad_source, pad_to_length=src_pad_to)
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target", left_pad=left_pad_target, pad_to_length=tgt_pad_to)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=tgt_pad_to,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        # NOTE(review): this branch reads batch["target"] and tgt_lengths, so
        # it relies on the samples having a target.
        _, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # Per-sample (source, target) offsets added to the raw alignment
        # indices: padding shifts plus, for targets, the row's flat offset.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): end-of-sentence index handed to the collater;
            defaults to ``src_dict.eos()``.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): forwarded to the collater
            (default: 1).
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            # Both vocabularies must agree on the special symbols.
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        # One row per example: (src_len, tgt_len) when targets exist,
        # otherwise just the source lengths.
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            # ``int`` is what the deprecated ``np.compat.long`` alias
            # resolved to, so the output dtype is unchanged.
            num_tokens = np.vectorize(self.num_tokens, otypes=[int])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, bucket_size)
                for bucket_size in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        """Return the ``(None, num_tokens)`` bucket shapes, or None if unbucketed."""
        return self.buckets

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for *index*.

        ``alignment`` and ``constraints`` entries are added when the
        corresponding datasets were provided.
        """
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa.
        #
        # Each step below builds on the item produced by the previous step so
        # that the options compose instead of overwriting one another.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if tgt_item is not None and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            # Language IDs are broadcast to one row per sample and moved to
            # the dtype/device of ``src_tokens``.
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a ``(src_len, tgt_len)`` tuple
        (``tgt_len`` is 0 without targets). This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            # (mergesort is stable, so the source sort keeps the target
            # ordering among equal source lengths)
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        """True if the source and (when present) the target support prefetching."""
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        """Forward *indices* to the source, target and alignment datasets."""
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )

View File

@@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .block_pair_dataset import BlockPairDataset
from .masked_lm_dataset import MaskedLMDataset
from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary
# Public names exported by this package; each one is imported above.
__all__ = [
"BertDictionary",
"BlockPairDataset",
"MaskedLMDataset",
"MaskedLMDictionary",
]

View File

@@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from fairseq.data import FairseqDataset
class BlockPairDataset(FairseqDataset):
"""Break a Dataset of tokens into sentence pair blocks for next sentence
prediction as well as masked language model.
High-level logics are:
1. break input tensor to tensor blocks
2. pair the blocks with 50% next sentence and 50% random sentence
3. return paired blocks as well as related segment labels
Args:
dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes: array of sentence lengths
dictionary: dictionary for the task
block_size: maximum block size
break_mode: mode for breaking copurs into block pairs. currently we support
2 modes
doc: respect document boundaries and each part of the pair should belong to on document
none: don't respect any boundary and cut tokens evenly
short_seq_prob: probability for generating shorter block pairs
doc_break_size: Size for empty line separating documents. Typically 1 if
the sentences have eos, 0 otherwise.
"""
def __init__(
self,
dataset,
dictionary,
sizes,
block_size,
break_mode="doc",
short_seq_prob=0.1,
doc_break_size=1,
):
super().__init__()
self.dataset = dataset
self.pad = dictionary.pad()
self.eos = dictionary.eos()
self.cls = dictionary.cls()
self.mask = dictionary.mask()
self.sep = dictionary.sep()
self.break_mode = break_mode
self.dictionary = dictionary
self.short_seq_prob = short_seq_prob
self.block_indices = []
assert len(dataset) == len(sizes)
if break_mode == "doc":
cur_doc = []
for sent_id, sz in enumerate(sizes):
assert doc_break_size == 0 or sz != 0, (
"when doc_break_size is non-zero, we expect documents to be"
"separated by a blank line with a single eos."
)
# empty line as document separator
if sz == doc_break_size:
if len(cur_doc) == 0:
continue
self.block_indices.append(cur_doc)
cur_doc = []
else:
cur_doc.append(sent_id)
max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP]
self.sent_pairs = []
self.sizes = []
for doc_id, doc in enumerate(self.block_indices):
self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
elif break_mode is None or break_mode == "none":
# each block should have half of the block size since we are constructing block pair
sent_length = (block_size - 3) // 2
total_len = sum(dataset.sizes)
length = math.ceil(total_len / sent_length)
def block_at(i):
start = i * sent_length
end = min(start + sent_length, total_len)
return (start, end)
sent_indices = np.array([block_at(i) for i in range(length)])
sent_sizes = np.array([e - s for s, e in sent_indices])
dataset_index = self._sent_to_dataset_index(sent_sizes)
# pair sentences
self._pair_sentences(dataset_index)
else:
raise ValueError("Invalid break_mode: " + break_mode)
def _pair_sentences(self, dataset_index):
"""
Give a list of evenly cut blocks/sentences, pair these sentences with 50%
consecutive sentences and 50% random sentences.
This is used for none break mode
"""
# pair sentences
for sent_id, sent in enumerate(dataset_index):
next_sent_label = (
1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
)
if next_sent_label:
next_sent = dataset_index[sent_id + 1]
else:
next_sent = dataset_index[
self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
]
self.sent_pairs.append((sent, next_sent, next_sent_label))
# The current blocks don't include the special tokens but the
# sizes already account for this
self.sizes.append(3 + sent[3] + next_sent[3])
def _sent_to_dataset_index(self, sent_sizes):
"""
Build index mapping block indices to the underlying dataset indices
"""
dataset_index = []
ds_idx, ds_remaining = -1, 0
for to_consume in sent_sizes:
sent_size = to_consume
if ds_remaining == 0:
ds_idx += 1
ds_remaining = sent_sizes[ds_idx]
start_ds_idx = ds_idx
start_offset = sent_sizes[ds_idx] - ds_remaining
while to_consume > ds_remaining:
to_consume -= ds_remaining
ds_idx += 1
ds_remaining = sent_sizes[ds_idx]
ds_remaining -= to_consume
dataset_index.append(
(
start_ds_idx, # starting index in dataset
start_offset, # starting offset within starting index
ds_idx, # ending index in dataset
sent_size, # sentence length
)
)
assert ds_remaining == 0
assert ds_idx == len(self.dataset) - 1
return dataset_index
def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
"""
Go through a single document and genrate sentence paris from it
"""
current_chunk = []
current_length = 0
curr = 0
# To provide more randomness, we decrease target seq length for parts of
# samples (10% by default). Note that max_num_tokens is the hard threshold
# for batching and will never be changed.
target_seq_length = max_num_tokens
if np.random.random() < self.short_seq_prob:
target_seq_length = np.random.randint(2, max_num_tokens)
# loop through all sentences in document
while curr < len(doc):
sent_id = doc[curr]
current_chunk.append(sent_id)
current_length = sum(sizes[current_chunk])
# split chunk and generate pair when exceed target_seq_length or
# finish the loop
if curr == len(doc) - 1 or current_length >= target_seq_length:
# split the chunk into 2 parts
a_end = 1
if len(current_chunk) > 2:
a_end = np.random.randint(1, len(current_chunk) - 1)
sent_a = current_chunk[:a_end]
len_a = sum(sizes[sent_a])
# generate next sentence label, note that if there is only 1 sentence
# in current chunk, label is always 0
next_sent_label = (
1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
)
if not next_sent_label:
# if next sentence label is 0, sample sent_b from a random doc
target_b_length = target_seq_length - len_a
rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
random_doc = self.block_indices[rand_doc_id]
random_start = np.random.randint(0, len(random_doc))
sent_b = []
len_b = 0
for j in range(random_start, len(random_doc)):
sent_b.append(random_doc[j])
len_b = sum(sizes[sent_b])
if len_b >= target_b_length:
break
# return the second part of the chunk since it's not used
num_unused_segments = len(current_chunk) - a_end
curr -= num_unused_segments
else:
# if next sentence label is 1, use the second part of chunk as sent_B
sent_b = current_chunk[a_end:]
len_b = sum(sizes[sent_b])
# currently sent_a and sent_B may be longer than max_num_tokens,
# truncate them and return block idx and offsets for them
sent_a, sent_b = self._truncate_sentences(
sent_a, sent_b, max_num_tokens
)
self.sent_pairs.append((sent_a, sent_b, next_sent_label))
self.sizes.append(3 + sent_a[3] + sent_b[3])
current_chunk = []
curr += 1
def _skip_sampling(self, total, skip_ids):
"""
Generate a random integer which is not in skip_ids. Sample range is [0, total)
TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
"""
rand_id = np.random.randint(total - len(skip_ids))
return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
"""
Trancate a pair of sentence to limit total length under max_num_tokens
Logics:
1. Truncate longer sentence
2. Tokens to be truncated could be at the beginning or the end of the sentnce
Returns:
Truncated sentences represented by dataset idx
"""
len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0
while True:
total_length = (
len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
)
if total_length <= max_num_tokens:
break
if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
if np.random.rand() < 0.5:
front_cut_a += 1
else:
end_cut_a += 1
else:
if np.random.rand() < 0.5:
front_cut_b += 1
else:
end_cut_b += 1
# calculate ds indices as well as offsets and return
truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
return truncated_sent_a, truncated_sent_b
def _cut_sentence(self, sent, front_cut, end_cut):
"""
Cut a sentence based on the numbers of tokens to be cut from beginning and end
Represent the sentence as dataset idx and return
"""
start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
while front_cut > 0:
if self.dataset.sizes[start_ds_idx] > front_cut:
offset += front_cut
break
else:
front_cut -= self.dataset.sizes[start_ds_idx]
start_ds_idx += 1
while end_cut > 0:
if self.dataset.sizes[end_ds_idx] > end_cut:
break
else:
end_cut -= self.dataset.sizes[end_ds_idx]
end_ds_idx -= 1
return start_ds_idx, offset, end_ds_idx, target_len
def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
"""
Fetch a block of tokens based on its dataset idx
"""
buffer = torch.cat(
[self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
)
s, e = offset, offset + length
return buffer[s:e]
def __getitem__(self, index):
block1, block2, next_sent_label = self.sent_pairs[index]
block1 = self._fetch_block(*block1)
block2 = self._fetch_block(*block2)
return block1, block2, next_sent_label
def __len__(self):
    """Return the number of entries in ``self.sizes``."""
    return len(self.sizes)
@property
def supports_prefetch(self):
    """Mirror the wrapped dataset's ``supports_prefetch`` flag (False if it has none)."""
    return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
    """
    Prefetch every underlying dataset item needed by the given pairs

    For each requested pair, the dataset index range of both blocks
    (descriptor fields 0 and 2, inclusive) is collected into one set which is
    then handed to the wrapped dataset's ``prefetch``.
    """
    wanted = set()
    for index in indices:
        first, second, _ = self.sent_pairs[index]
        wanted.update(range(first[0], first[2] + 1))
        wanted.update(range(second[0], second[2] + 1))
    self.dataset.prefetch(wanted)

View File

@@ -0,0 +1,303 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, List, Tuple
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset, data_utils
from fairseq.data.concat_dataset import ConcatDataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset
class MaskedLMDataset(FairseqDataset):
"""
A wrapper Dataset for masked language modelling. The dataset
wraps around TokenBlockDataset or BlockPairDataset and creates a batch
where the input blocks are masked according to the specified masking
probability. Additionally the batch can also contain sentence level targets
if this is specified.
Args:
dataset: Dataset which generates blocks of data. Only BlockPairDataset
and TokenBlockDataset are supported.
sizes: Sentence lengths
vocab: Dictionary with the vocabulary and special tokens.
pad_idx: Id of padding token in dictionary
mask_idx: Id of mask token in dictionary
classif_token_idx: Id of classification token in dictionary. This is the
token associated with the sentence embedding (Eg: CLS for BERT)
sep_token_idx: Id of separator token in dictionary
(Eg: SEP in BERT)
seed: Seed for random number generator for reproducibility.
shuffle: Shuffle the elements before batching.
has_pairs: Specifies whether the underlying dataset
generates a pair of blocks along with a sentence_target or not.
Setting it to True assumes that the underlying dataset generates a
label for the pair of sentences which is surfaced as
sentence_target. The default value assumes a single block with no
sentence target.
segment_id: An optional segment id for filling in the segment labels
when we are in the single block setting (Eg: XLM). Default is 0.
masking_ratio: specifies what percentage of the blocks should be masked.
masking_prob: specifies the probability of a given token being
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary.
"""
def __init__(
    self,
    dataset: FairseqDataset,
    sizes: np.ndarray,
    vocab: Dictionary,
    pad_idx: int,
    mask_idx: int,
    classif_token_idx: int,
    sep_token_idx: int,
    seed: int = 1,
    shuffle: bool = True,
    has_pairs: bool = True,
    segment_id: int = 0,
    masking_ratio: float = 0.15,
    masking_prob: float = 0.8,
    random_token_prob: float = 0.1,
):
    """Store the masking configuration after checking the wrapped dataset type."""
    # Only these block-producing datasets are supported
    supported_types = (TokenBlockDataset, BlockPairDataset, ConcatDataset)
    assert isinstance(dataset, supported_types), (
        "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
        "ConcatDataset"
    )

    self.dataset = dataset
    self.sizes = np.array(sizes)
    self.vocab = vocab

    # special token ids
    self.pad_idx = pad_idx
    self.mask_idx = mask_idx
    self.classif_token_idx = classif_token_idx
    self.sep_token_idx = sep_token_idx

    # batching / layout options
    self.shuffle = shuffle
    self.seed = seed
    self.has_pairs = has_pairs
    self.segment_id = segment_id

    # masking probabilities
    self.masking_ratio = masking_ratio
    self.masking_prob = masking_prob
    self.random_token_prob = random_token_prob

    # In the single-block setting every example grows by one token: the
    # classification token that is prepended at collate time
    if not has_pairs:
        self.sizes = self.sizes + 1
def __getitem__(self, index: int):
# if has_pairs, then expect 2 blocks and a sentence target
if self.has_pairs:
(block_one, block_two, sentence_target) = self.dataset[index]
else:
block_one = self.dataset[index]
return {
"id": index,
"block_one": block_one,
"block_two": block_two if self.has_pairs else None,
"sentence_target": sentence_target if self.has_pairs else None,
}
def __len__(self):
    """Return the length of the wrapped dataset."""
    return len(self.dataset)
def _mask_block(
self,
sentence: np.ndarray,
mask_idx: int,
pad_idx: int,
dictionary_token_range: Tuple,
):
"""
Mask tokens for Masked Language Model training
Samples mask_ratio tokens that will be predicted by LM.
Note:This function may not be efficient enough since we had multiple
conversions between np and torch, we can replace them with torch
operators later.
Args:
sentence: 1d tensor to be masked
mask_idx: index to use for masking the sentence
pad_idx: index to use for masking the target for tokens we aren't
predicting
dictionary_token_range: range of indices in dictionary which can
be used for random word replacement
(e.g. without special characters)
Return:
masked_sent: masked sentence
target: target with words which we are not predicting replaced
by pad_idx
"""
masked_sent = np.copy(sentence)
sent_length = len(sentence)
mask_num = math.ceil(sent_length * self.masking_ratio)
mask = np.random.choice(sent_length, mask_num, replace=False)
target = np.copy(sentence)
for i in range(sent_length):
if i in mask:
rand = np.random.random()
# replace with mask if probability is less than masking_prob
# (Eg: 0.8)
if rand < self.masking_prob:
masked_sent[i] = mask_idx
# replace with random token if probability is less than
# masking_prob + random_token_prob (Eg: 0.9)
elif rand < (self.masking_prob + self.random_token_prob):
# sample random token from dictionary
masked_sent[i] = np.random.randint(
dictionary_token_range[0], dictionary_token_range[1]
)
else:
target[i] = pad_idx
return masked_sent, target
def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
    """
    Does the heavy lifting for creating a batch from the input list of
    examples. The logic is as follows:
        1. Mask the input blocks. In case has_pair is True then we have 2
           blocks to mask.
        2. Prepend the first masked block tensor with the special token
           used as sentence embedding. Eg: CLS in BERT. This happens
           irrespective of the value of has_pair.
        3. If has_pair is True, then append the first masked block with the
           special separator token (eg: SEP for BERT) and compute segment
           label accordingly. In this case, also append the second masked
           block with this special separator token and compute its segment
           label.
        4. For the targets tensor, prepend and append with padding index
           accordingly.
        5. Concatenate all tensors.

    Note that each dict in ``samples`` is modified in place: the keys
    "source", "segment_labels" and "lm_target" are added to it.

    Args:
        samples: examples as produced by ``__getitem__``
        pad_idx: padding index handed to ``data_utils.collate_tokens``
        eos_idx: end-of-sentence index handed to ``data_utils.collate_tokens``

    Returns:
        an empty dict for an empty batch, otherwise a dict with the keys
        "id", "ntokens", "net_input" ("src_tokens", "segment_labels"),
        "lm_target", "sentence_target" (None unless has_pairs) and
        "nsentences".
    """
    if len(samples) == 0:
        return {}
    # To ensure determinism, we reset the state of the PRNG after every
    # batch based on the seed and the first id of the batch. This ensures
    # that across epochs we get the same mask for the same example. This
    # is needed for reproducibility and is how BERT does masking
    # TODO: Can we add determinism without this constraint?
    with data_utils.numpy_seed(self.seed + samples[0]["id"]):
        for s in samples:

            # token range is needed for replacing with random token during
            # masking
            token_range = (self.vocab.nspecial, len(self.vocab))

            # mask according to specified probabilities.
            masked_blk_one, masked_tgt_one = self._mask_block(
                s["block_one"],
                self.mask_idx,
                self.pad_idx,
                token_range,
            )

            # prepend the classification token; its target slot is padding
            tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
            targets = np.concatenate([[self.pad_idx], masked_tgt_one])
            # single-block default: every position carries self.segment_id
            segments = np.ones(len(tokens)) * self.segment_id

            # if has_pairs is True then we need to add the SEP token to both
            # the blocks after masking and re-compute segments based on the new
            # lengths.
            if self.has_pairs:
                tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                targets_one = np.concatenate([targets, [self.pad_idx]])

                masked_blk_two, masked_tgt_two = self._mask_block(
                    s["block_two"], self.mask_idx, self.pad_idx, token_range
                )
                tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                # block + 1 sep + 1 special (CLS)
                segments_one = np.zeros(len(tokens_one))
                # block + 1 sep
                segments_two = np.ones(len(tokens_two))

                tokens = np.concatenate([tokens_one, tokens_two])
                targets = np.concatenate([targets_one, targets_two])
                segments = np.concatenate([segments_one, segments_two])

            # stash the per-example tensors on the sample itself
            s["source"] = torch.LongTensor(tokens)
            s["segment_labels"] = torch.LongTensor(segments)
            s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            # right-pad the tensors stored under `key` into one batch tensor
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }
def collater(self, samples: List[Dict]):
    """Merge a list of samples to form a mini-batch.

    Delegates to ``_collate`` using the vocabulary's pad and eos indices.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch of data
    """
    return self._collate(samples, self.vocab.pad(), self.vocab.eos())
def num_tokens(self, index: int):
    """
    Return the number of tokens in a sample. This value is used to
    enforce max-tokens during batching.

    The value is read from ``self.sizes``.
    """
    return self.sizes[index]
def size(self, index: int):
    """
    Return an example's size as a float or tuple. This value is used when
    filtering a dataset with max-positions.

    Here it is the ``self.sizes`` entry for ``index``.
    """
    return self.sizes[index]
def ordered_indices(self):
    """
    Return the order in which examples are visited when building batches

    With ``shuffle`` this is a random permutation of all indices; otherwise
    the indices are sorted by size, ties broken by ascending index.
    """
    num_examples = len(self)
    if not self.shuffle:
        # np.lexsort treats the LAST key as the primary one
        return np.lexsort([np.arange(num_examples), self.sizes])
    return np.random.permutation(num_examples)
@property
def supports_prefetch(self):
    """Mirror the wrapped dataset's ``supports_prefetch`` flag (False if it has none)."""
    return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
    """Forward the prefetch request for ``indices`` to the wrapped dataset."""
    self.dataset.prefetch(indices)

View File

@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import Dictionary
class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks. This extends Dictionary by
    adding the mask symbol.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
    ):
        """Build the base dictionary, then register ``mask`` as a special symbol."""
        super().__init__(pad=pad, eos=eos, unk=unk)
        self.mask_word = mask
        self.mask_index = self.add_symbol(mask)
        # recount after adding the mask so that it is included in nspecial
        self.nspecial = len(self.symbols)

    def mask(self):
        """Helper to get index of mask symbol"""
        return self.mask_index
class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for BERT task. This extends MaskedLMDictionary by adding support
    for cls and sep symbols.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        """Build the masked-LM dictionary, then register ``cls`` and ``sep``."""
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word = cls
        self.sep_word = sep
        # cls is added before sep
        self.cls_index = self.add_symbol(cls)
        self.sep_index = self.add_symbol(sep)
        # recount so that cls and sep are included in nspecial
        self.nspecial = len(self.symbols)

    def cls(self):
        """Helper to get index of cls symbol"""
        return self.cls_index

    def sep(self):
        """Helper to get index of sep symbol"""
        return self.sep_index

View File

@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ListDataset(BaseWrapperDataset):
    """Wrapper whose ``collater`` returns the list of samples unchanged."""

    def __init__(self, dataset, sizes=None):
        super().__init__(dataset)
        self._sizes = sizes

    def __iter__(self):
        """Iterate over the wrapped dataset's items."""
        yield from self.dataset

    def collater(self, samples):
        """Return ``samples`` as-is; no padding or merging is performed."""
        return samples

    @property
    def sizes(self):
        """The ``sizes`` value given at construction (may be None)."""
        return self._sizes

    def num_tokens(self, index):
        """Size of the item at ``index``, looked up in ``sizes``."""
        return self.sizes[index]

    def size(self, index):
        """Size of the item at ``index``, looked up in ``sizes``."""
        return self.sizes[index]

    def set_epoch(self, epoch):
        """Ignore epoch changes."""
        pass

View File

@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from typing import Dict
from fairseq.data.monolingual_dataset import MonolingualDataset
from . import FairseqDataset
class LMContextWindowDataset(FairseqDataset):
    """
    Wraps a MonolingualDataset and provides more context for evaluation.

    Each item in the new dataset will have a maximum size of
    ``tokens_per_sample + context_window``.

    The extra context is taken from tokens seen in earlier ``collater``
    calls, so the result of a call depends on the calls made before it.

    Args:
        dataset: dataset to wrap
        tokens_per_sample (int): the max number of tokens in each dataset item
        context_window (int): the number of accumulated tokens to add to each
            dataset item
        pad_idx (int): padding symbol
    """

    def __init__(
        self,
        dataset: MonolingualDataset,
        tokens_per_sample: int,
        context_window: int,
        pad_idx: int,
    ):
        assert context_window > 0
        self.dataset = dataset
        self.tokens_per_sample = tokens_per_sample
        self.context_window = context_window
        self.pad_idx = pad_idx
        # rolling buffer of previously seen tokens; starts out empty
        self.prev_tokens = np.empty([0])

    def __getitem__(self, index):
        """Return the wrapped dataset's item unchanged."""
        return self.dataset[index]

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def collater(self, samples) -> Dict:
        """Collate with the wrapped dataset, then prepend the stored context.

        The batch returned by ``self.dataset.collater`` is updated in place:
        ``net_input.src_tokens`` and ``target`` are replaced by tensors that
        are ``context_window`` columns wider, ``net_input.src_lengths`` is
        increased by the context length of each row, and a "start_indices"
        list records where the original tokens begin in each row.
        """
        sample = self.dataset.collater(samples)

        pad = self.pad_idx
        max_sample_len = self.tokens_per_sample + self.context_window

        bsz, tsz = sample["net_input"]["src_tokens"].shape
        start_idxs = [0] * bsz
        toks = sample["net_input"]["src_tokens"]
        lengths = sample["net_input"]["src_lengths"]
        tgt = sample["target"]
        new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
        new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
        # number of non-pad tokens in each row
        sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
        for i in range(bsz):
            sample_len = sample_lens[i]
            # drop the oldest context tokens if context + sample is too long
            extra = len(self.prev_tokens) + sample_len - max_sample_len
            if extra > 0:
                self.prev_tokens = self.prev_tokens[extra:]
            # pad so that context + row + pads fills the widened row exactly
            pads = np.full(self.context_window - len(self.prev_tokens), pad)
            new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
            # targets are shifted right by the context length; the context
            # positions themselves keep the pad value
            new_tgt[
                i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
            ] = tgt[i]
            start_idxs[i] = len(self.prev_tokens)
            lengths[i] += len(self.prev_tokens)
            # the last context_window non-pad tokens become the next context
            self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :]
        sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks)
        sample["target"] = torch.from_numpy(new_tgt)
        sample["start_indices"] = start_idxs
        return sample

    def num_tokens(self, index):
        """Delegate to the wrapped dataset."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Delegate to the wrapped dataset."""
        return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE we don't shuffle the data to retain access to the previous dataset elements
        return np.arange(len(self.dataset))

    @property
    def supports_prefetch(self):
        """Mirror the wrapped dataset's ``supports_prefetch`` flag (False if it has none)."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        return self.dataset.prefetch(indices)

View File

@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
from . import BaseWrapperDataset
class LRUCacheDataset(BaseWrapperDataset):
    """Wrapper that memoises the most recent lookups of the wrapped dataset.

    ``__getitem__`` keeps the last 8 results. ``collater`` results are
    cached only when ``samples`` is hashable; see ``collater``.
    """

    def __init__(self, dataset, token=None):
        # `token` is accepted but not used by this class.
        super().__init__(dataset)

    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        """Return the wrapped dataset's item at ``index`` (memoised)."""
        return self.dataset[index]

    @lru_cache(maxsize=8)
    def _cached_collater(self, samples):
        """Memoised collate for hashable ``samples``."""
        return self.dataset.collater(samples)

    def collater(self, samples):
        """Collate ``samples`` with the wrapped dataset's collater.

        ``lru_cache`` hashes its arguments, so decorating this method
        directly raises ``TypeError`` for a plain list of samples. Unhashable
        input is therefore collated directly; hashable input (e.g. a tuple)
        still goes through the cache.
        """
        try:
            hash(samples)
        except TypeError:
            return self.dataset.collater(samples)
        return self._cached_collater(samples)

View File

@@ -0,0 +1,220 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import Dictionary, data_utils
from . import BaseWrapperDataset, LRUCacheDataset
class MaskTokensDataset(BaseWrapperDataset):
"""
A wrapper Dataset for masked language modeling.
Input items are masked according to the specified masking probability.
Args:
dataset: Dataset to wrap.
sizes: Sentence lengths
vocab: Dictionary with the vocabulary and special tokens.
pad_idx: Id of pad token in vocab
mask_idx: Id of mask token in vocab
return_masked_tokens: controls whether to return the non-masked tokens
(the default) or to return a tensor with the original masked token
IDs (and *pad_idx* elsewhere). The latter is useful as targets for
masked LM training.
seed: Seed for random number generator for reproducibility.
mask_prob: probability of replacing a token with *mask_idx*.
leave_unmasked_prob: probability that a masked token is unmasked.
random_token_prob: probability of replacing a masked token with a
random token from the vocabulary.
freq_weighted_replacement: sample random replacement words based on
word frequencies in the vocab.
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
bpe: BPE to use for whole-word masking.
mask_multiple_length : repeat each mask index multiple times. Default
value is 1.
mask_stdev : standard deviation of masks distribution in case of
multiple masking. Default value is 0.
"""
@classmethod
def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
    """Return the source and target datasets for masked LM training.

    Both datasets wrap the same LRU-cached view of ``dataset``. The first is
    built with ``return_masked_tokens=False`` and the second with
    ``return_masked_tokens=True``; each is wrapped in its own
    ``LRUCacheDataset``.
    """
    dataset = LRUCacheDataset(dataset)
    source_dataset = LRUCacheDataset(
        cls(dataset, *args, **kwargs, return_masked_tokens=False)
    )
    target_dataset = LRUCacheDataset(
        cls(dataset, *args, **kwargs, return_masked_tokens=True)
    )
    return source_dataset, target_dataset
def __init__(
    self,
    dataset: torch.utils.data.Dataset,
    vocab: Dictionary,
    pad_idx: int,
    mask_idx: int,
    return_masked_tokens: bool = False,
    seed: int = 1,
    mask_prob: float = 0.15,
    leave_unmasked_prob: float = 0.1,
    random_token_prob: float = 0.1,
    freq_weighted_replacement: bool = False,
    mask_whole_words: torch.Tensor = None,
    mask_multiple_length: int = 1,
    mask_stdev: float = 0.0,
):
    """Validate the masking parameters and store them on the instance."""
    # parameter sanity checks
    assert 0.0 < mask_prob < 1.0
    assert 0.0 <= random_token_prob <= 1.0
    assert 0.0 <= leave_unmasked_prob <= 1.0
    assert random_token_prob + leave_unmasked_prob <= 1.0
    assert mask_multiple_length >= 1
    assert mask_stdev >= 0.0

    self.dataset = dataset
    self.vocab = vocab
    self.pad_idx = pad_idx
    self.mask_idx = mask_idx
    self.return_masked_tokens = return_masked_tokens
    self.seed = seed
    self.mask_prob = mask_prob
    self.leave_unmasked_prob = leave_unmasked_prob
    self.random_token_prob = random_token_prob
    self.mask_whole_words = mask_whole_words
    self.mask_multiple_length = mask_multiple_length
    self.mask_stdev = mask_stdev

    # Sampling distribution for random replacement; it is only built when
    # random replacement can actually happen
    if random_token_prob > 0.0:
        weights = (
            np.array(self.vocab.count)
            if freq_weighted_replacement
            else np.ones(len(self.vocab))
        )
        # special symbols are never drawn as replacements
        weights[: self.vocab.nspecial] = 0
        self.weights = weights / weights.sum()

    self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
    """Always True for this dataset."""
    return True  # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
    """Record the current epoch; it is part of the seed used when masking items."""
    super().set_epoch(epoch)
    self.epoch = epoch
def __getitem__(self, index: int):
    """Return the masked item for ``index`` under the current seed and epoch."""
    # seed and epoch are passed explicitly so they form part of the cache key
    return self.__getitem_cached__(self.seed, self.epoch, index)
@lru_cache(maxsize=8)
def __getitem_cached__(self, seed: int, epoch: int, index: int):
    """Compute the masked view of item ``index``.

    ``seed`` and ``epoch`` are not read in the body (``self.seed`` and
    ``self.epoch`` are used instead); they serve as part of the
    ``lru_cache`` key.

    Returns:
        a tensor built from a NumPy array: when ``return_masked_tokens`` is
        set, the original token ids at the selected positions and
        ``pad_idx`` elsewhere; otherwise a copy of the item with selected
        positions replaced by ``mask_idx`` or a random vocabulary entry.
    """
    with data_utils.numpy_seed(self.seed, self.epoch, index):
        item = self.dataset[index]
        sz = len(item)

        assert (
            self.mask_idx not in item
        ), "Dataset contains mask_idx (={}), this is not expected!".format(
            self.mask_idx,
        )

        if self.mask_whole_words is not None:
            # from here on `sz` counts words rather than tokens
            word_begins_mask = self.mask_whole_words.gather(0, item)
            word_begins_idx = word_begins_mask.nonzero().view(-1)
            sz = len(word_begins_idx)
            words = np.split(word_begins_mask, word_begins_idx)[1:]
            assert len(words) == sz
            word_lens = list(map(len, words))

        # decide elements to mask
        mask = np.full(sz, False)
        num_mask = int(
            # add a random number for probabilistic rounding
            self.mask_prob * sz / float(self.mask_multiple_length)
            + np.random.rand()
        )

        # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
        mask_idc = np.random.choice(sz, num_mask, replace=False)
        if self.mask_stdev > 0.0:
            # each start index gets its own span length, drawn from a normal
            # distribution and clamped at zero
            lengths = np.random.normal(
                self.mask_multiple_length, self.mask_stdev, size=num_mask
            )
            lengths = [max(0, int(round(x))) for x in lengths]
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )
        else:
            # fixed span length: every start index is extended by
            # mask_multiple_length - 1 following positions
            mask_idc = np.concatenate(
                [mask_idc + i for i in range(self.mask_multiple_length)]
            )
        # discard span positions that run past the end
        mask_idc = mask_idc[mask_idc < len(mask)]
        try:
            mask[mask_idc] = True
        except:  # something wrong
            print(
                "Assigning mask indexes {} to mask {} failed!".format(
                    mask_idc, mask
                )
            )
            raise

        if self.return_masked_tokens:
            # exit early if we're just returning the masked tokens
            # (i.e., the targets for masked LM training)
            if self.mask_whole_words is not None:
                # expand the word-level mask back to token level
                mask = np.repeat(mask, word_lens)
            new_item = np.full(len(mask), self.pad_idx)
            new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
            return torch.from_numpy(new_item)

        # decide unmasking and random replacement
        rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
        if rand_or_unmask_prob > 0.0:
            rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
            if self.random_token_prob == 0.0:
                unmask = rand_or_unmask
                rand_mask = None
            elif self.leave_unmasked_prob == 0.0:
                unmask = None
                rand_mask = rand_or_unmask
            else:
                # split the selected positions between "leave unchanged" and
                # "random replacement" in proportion to their probabilities
                unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                decision = np.random.rand(sz) < unmask_prob
                unmask = rand_or_unmask & decision
                rand_mask = rand_or_unmask & (~decision)
        else:
            unmask = rand_mask = None

        if unmask is not None:
            # unmask is a subset of mask, so XOR clears exactly those positions
            mask = mask ^ unmask

        if self.mask_whole_words is not None:
            mask = np.repeat(mask, word_lens)

        new_item = np.copy(item)
        new_item[mask] = self.mask_idx
        if rand_mask is not None:
            num_rand = rand_mask.sum()
            if num_rand > 0:
                if self.mask_whole_words is not None:
                    rand_mask = np.repeat(rand_mask, word_lens)
                    num_rand = rand_mask.sum()

                # random replacements overwrite mask_idx at these positions
                new_item[rand_mask] = np.random.choice(
                    len(self.vocab),
                    num_rand,
                    p=self.weights,
                )

        return torch.from_numpy(new_item)

View File

@@ -0,0 +1,253 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import FairseqDataset, data_utils
def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    """Merge monolingual samples into a right-padded mini-batch dict.

    Each sample is a dict with "id", "source" and "target"; "target" may be
    None (the padded source is then reused as target) or a list of tensors
    (each list position is collated separately). An empty ``samples`` list
    yields an empty dict.
    """
    if len(samples) == 0:
        return {}

    def pad_batch(tensors):
        # all tensors are collated with the same padding configuration
        return data_utils.collate_tokens(
            tensors,
            pad_idx,
            eos_idx,
            left_pad=False,
            pad_to_length=fixed_pad_length,
            pad_to_bsz=pad_to_bsz,
        )

    def merge(key, is_list=False):
        if not is_list:
            return pad_batch([s[key] for s in samples])
        # one collated tensor per position of the per-sample list
        return [
            pad_batch([s[key][i] for s in samples])
            for i in range(len(samples[0][key]))
        ]

    src_tokens = merge("source")
    first_target = samples[0]["target"]
    if first_target is None:
        target = src_tokens
    else:
        target = merge("target", isinstance(first_target, list))

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
class MonolingualDataset(FairseqDataset):
"""
A wrapper around torch.utils.data.Dataset for monolingual data.
Args:
dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths
src_vocab (~fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching
(default: False).
"""
def __init__(
self,
dataset,
sizes,
src_vocab,
tgt_vocab=None,
add_eos_for_other_targets=False,
shuffle=False,
targets=None,
add_bos_token=False,
fixed_pad_length=None,
pad_to_bsz=None,
src_lang_idx=None,
tgt_lang_idx=None,
):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = src_vocab
self.tgt_vocab = tgt_vocab or src_vocab
self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle
self.add_bos_token = add_bos_token
self.fixed_pad_length = fixed_pad_length
self.pad_to_bsz = pad_to_bsz
self.src_lang_idx = src_lang_idx
self.tgt_lang_idx = tgt_lang_idx
assert targets is None or all(
t in {"self", "future", "past"} for t in targets
), "targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0:
targets = None
self.targets = targets
def __getitem__(self, index):
if self.targets is not None:
# *future_target* is the original sentence
# *source* is shifted right by 1 (maybe left-padded with eos)
# *past_target* is shifted right by 2 (left-padded as needed)
#
# Left-to-right language models should condition on *source* and
# predict *future_target*.
# Right-to-left language models should condition on *source* and
# predict *past_target*.
source, future_target, past_target = self.dataset[index]
source, target = self._make_source_target(
source, future_target, past_target
)
else:
source = self.dataset[index]
target = None
source, target = self._maybe_add_bos(source, target)
return {"id": index, "source": source, "target": target}
def __len__(self):
    """Return the length of the wrapped dataset."""
    return len(self.dataset)
def _make_source_target(self, source, future_target, past_target):
if self.targets is not None:
target = []
if (
self.add_eos_for_other_targets
and (("self" in self.targets) or ("past" in self.targets))
and source[-1] != self.vocab.eos()
):
# append eos at the end of source
source = torch.cat([source, source.new([self.vocab.eos()])])
if "future" in self.targets:
future_target = torch.cat(
[future_target, future_target.new([self.vocab.pad()])]
)
if "past" in self.targets:
# first token is before the start of sentence which is only used in "none" break mode when
# add_eos_for_other_targets is False
past_target = torch.cat(
[
past_target.new([self.vocab.pad()]),
past_target[1:],
source[-2, None],
]
)
for t in self.targets:
if t == "self":
target.append(source)
elif t == "future":
target.append(future_target)
elif t == "past":
target.append(past_target)
else:
raise Exception("invalid target " + t)
if len(target) == 1:
target = target[0]
else:
target = future_target
return source, self._filter_vocab(target)
def _maybe_add_bos(self, source, target):
if self.add_bos_token:
source = torch.cat([source.new([self.vocab.bos()]), source])
if target is not None:
target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
return source, target
def num_tokens_vec(self, indices):
    """Return the number of tokens for a set of positions defined by indices.
    This value is used to enforce ``--max-tokens`` during batching.

    The values are read from ``self.sizes``.
    """
    return self.sizes[indices]
def _filter_vocab(self, target):
if len(self.tgt_vocab) != len(self.vocab):
def _filter(target):
mask = target.ge(len(self.tgt_vocab))
if mask.any():
target[mask] = self.tgt_vocab.unk()
return target
if isinstance(target, list):
return [_filter(t) for t in target]
return _filter(target)
return target
def collater(self, samples):
    """Merge a list of samples to form a mini-batch.

    Delegates to the module-level ``collate`` with the source vocabulary's
    pad and eos indices and this dataset's ``fixed_pad_length`` and
    ``pad_to_bsz`` settings.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch with the following keys:

            - `id` (LongTensor): example IDs in the original input order
            - `ntokens` (int): total number of tokens in the batch
            - `net_input` (dict): the input to the Model, containing keys:

              - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                the source sentence of shape `(bsz, src_len)`. Padding will
                appear on the right.

            - `target` (LongTensor): a padded 2D Tensor of tokens in the
              target sentence of shape `(bsz, tgt_len)`. Padding will appear
              on the right.
    """
    return collate(
        samples,
        self.vocab.pad(),
        self.vocab.eos(),
        self.fixed_pad_length,
        self.pad_to_bsz,
    )
def num_tokens(self, index):
    """Return the number of tokens in a sample. This value is used to
    enforce ``--max-tokens`` during batching.

    The value is read from ``self.sizes``.
    """
    return self.sizes[index]
def size(self, index):
    """Return an example's size as a float or tuple. This value is used when
    filtering a dataset with ``--max-positions``.

    Here it is the ``self.sizes`` entry for ``index``.
    """
    return self.sizes[index]
def ordered_indices(self):
    """Return an ordered list of indices. Batches will be constructed based
    on this order.

    Indices are always sorted by size; ``shuffle`` only decides whether ties
    between equally sized examples are broken randomly or by ascending index.
    """
    num_examples = len(self)
    if self.shuffle:
        tie_breaker = np.random.permutation(num_examples)
    else:
        tie_breaker = np.arange(num_examples)
    # np.lexsort treats the LAST key as the primary one
    return np.lexsort([tie_breaker, self.sizes])
@property
def supports_prefetch(self):
    """Mirror the wrapped dataset's ``supports_prefetch`` flag (False if it has none)."""
    return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
    """Forward the prefetch request for ``indices`` to the wrapped dataset."""
    self.dataset.prefetch(indices)

View File

@@ -0,0 +1,285 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import logging
import time
from collections import OrderedDict
from typing import Dict, List, Optional
import numpy as np
from fairseq.data import data_utils
from . import FairseqDataset
logger = logging.getLogger(__name__)
class MultiCorpusDataset(FairseqDataset):
"""
Stores multiple instances of FairseqDataset together.
Unless batch_sample=True, requires each instance
to be the same dataset, as the collate method needs to work on batches with
samples from each dataset.
Allows specifying a distribution over the datasets to use. Note that unlike
MultiCorpusSampledDataset, this distribution allows sampling for each item,
rather than on a batch level. Note that datasets with sampling probability
of 0 will be skipped.
Each time ordered_indices() is called, a new sample is generated with
the specified distribution.
Args:
datasets: a OrderedDict of FairseqDataset instances.
distribution: a List containing the probability of getting an utterance from
corresponding dataset
seed: random seed for sampling the datasets
sort_indices: if true, will sort the ordered indices by size
batch_sample: if true, will ensure each batch is from a single dataset
"""
def __init__(
    self,
    datasets: Dict[str, FairseqDataset],
    distribution: List[float],
    seed: int,
    sort_indices: bool = False,
    batch_sample: bool = False,
    distributed_rank: Optional[int] = None,
):
    """Validate the inputs and precompute per-dataset sizes and offsets.

    Args:
        datasets: an OrderedDict of FairseqDataset instances, all of the
            same type.
        distribution: one sampling probability per dataset, in the same
            order; the probabilities must sum to 1 (up to floating point
            rounding).
        seed: random seed, stored for later sampling.
        sort_indices: stored as given.
        batch_sample: stored as given.
        distributed_rank: stored as given.
    """
    super().__init__()
    assert isinstance(datasets, OrderedDict)
    assert len(datasets) == len(distribution)
    # Compare with a tolerance: an exact `== 1` rejects valid distributions
    # such as [0.1, 0.2, 0.7], whose float sum is 0.9999999999999999.
    assert np.isclose(sum(distribution), 1.0), (
        f"distribution must sum to 1, got {sum(distribution)}"
    )
    self.datasets = datasets
    self.distribution = distribution
    self.seed = seed
    self.sort_indices = sort_indices
    self.batch_sample = batch_sample
    self.distributed_rank = distributed_rank

    # Avoid repeated conversions to list later
    self.dataset_list = list(datasets.values())
    self.total_num_instances = 0

    first_dataset = self.dataset_list[0]

    # num_instances_per_dataset[i]: items dataset i contributes (0 when its
    # sampling probability is 0); dataset_offsets[i]: global index at which
    # dataset i starts.
    self.num_instances_per_dataset = []
    self.dataset_offsets = []
    for i, dataset in enumerate(self.dataset_list):
        assert isinstance(dataset, FairseqDataset)
        # NOTE(review): this check runs regardless of batch_sample.
        assert type(dataset) is type(first_dataset)
        self.num_instances_per_dataset.append(
            0 if self.distribution[i] == 0 else len(dataset)
        )
        self.dataset_offsets.append(self.total_num_instances)
        self.total_num_instances += self.num_instances_per_dataset[i]
def ordered_indices(self):
    """Draw a fresh sample of global indices for the current epoch.

    The draw is seeded with ``(self.seed, self.epoch)``. The result always
    has ``self.total_num_instances`` entries, shuffled and then, if
    ``sort_indices`` is set, sorted by ``self.num_tokens``.

    Returns:
        np.ndarray of dtype int64 holding the sampled global indices.
    """
    start = time.time()
    with data_utils.numpy_seed(self.seed, self.epoch):
        logger.info(
            f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
        )
        sampled_indices = []
        num_selected_instances = 0

        # For each dataset i, sample self.distribution[i] * self.total_num_instances
        for i, key in enumerate(self.datasets):
            if self.distribution[i] == 0:
                # skip dataset if sampling probability is 0
                continue

            if i < len(self.datasets) - 1:
                num_instances = int(self.distribution[i] * self.total_num_instances)
                high = self.dataset_offsets[i + 1]
            else:
                # the last dataset takes whatever is left, so the total
                # comes out exact despite the int() truncation above
                num_instances = self.total_num_instances - num_selected_instances
                high = self.total_num_instances

            logger.info(f"sampling {num_instances} from {key} dataset")
            num_selected_instances += num_instances

            # First, add k copies of the dataset where k = num_instances // len(dataset).
            # This ensures an equal distribution of the data points as much as possible.
            # For the remaining entries randomly sample them
            dataset_size = len(self.datasets[key])
            num_copies = num_instances // dataset_size
            # the remainder: a random subset of this dataset's global index
            # range, drawn without replacement
            dataset_indices = (
                np.random.permutation(high - self.dataset_offsets[i])
                + self.dataset_offsets[i]
            )[: num_instances - num_copies * dataset_size]
            if num_copies > 0:
                sampled_indices += list(
                    np.concatenate(
                        (
                            np.repeat(
                                np.arange(self.dataset_offsets[i], high), num_copies
                            ),
                            dataset_indices,
                        )
                    )
                )
            else:
                sampled_indices += list(dataset_indices)

        assert (
            len(sampled_indices) == self.total_num_instances
        ), f"{len(sampled_indices)} vs {self.total_num_instances}"

        np.random.shuffle(sampled_indices)
        if self.sort_indices:
            sampled_indices.sort(key=lambda i: self.num_tokens(i))

        logger.info(
            "multi_corpus_dataset ordered_indices took {}s".format(
                time.time() - start
            )
        )
        return np.array(sampled_indices, dtype=np.int64)
def _map_index(self, index: int):
"""
If dataset A has length N and dataset B has length M
then index 1 maps to index 1 of dataset A, and index N + 1
maps to index 1 of B.
"""
counter = 0
for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
if index < counter + num_instances:
return index - counter, key
counter += num_instances
raise ValueError(
"Invalid index: {}, max: {}".format(index, self.total_num_instances)
)
def __len__(self):
"""
Length of this dataset is the sum of individual datasets
"""
return self.total_num_instances
async def getitem(self, index):
    """
    Fetch the sample at global ``index`` and tag it with ``"full_id"``.

    The underlying dataset's own async ``getitem`` is awaited when it has
    one; otherwise plain indexing is used. Any exception gets the dataset
    key prepended to its ``args`` before being re-raised.
    """
    local_index, key = self._map_index(index)
    try:
        source = self.datasets[key]
        if not hasattr(source, "getitem"):
            sample = source[local_index]
        else:
            sample = await source.getitem(local_index)
        # remember the global index so collater() can find the dataset again
        sample["full_id"] = index
        return sample
    except Exception as e:
        e.args = (f"Error from {key} dataset", *e.args)
        raise
def __getitem__(self, index):
return asyncio.run(self.getitem(index))
async def getitems(self, indices):
    """
    Fetch several samples concurrently and return them in input order.

    All reads are started up front and awaited together, which helps when
    I/O is latency bound. At most 32 ``getitem`` calls run at a time.
    """
    max_concurrency = 32
    gate = asyncio.Semaphore(max_concurrency)

    async def fetch_one(idx):
        # the semaphore caps how many reads are in flight
        async with gate:
            return await self.getitem(idx)

    pending = [fetch_one(idx) for idx in indices]
    return await asyncio.gather(*pending)
def __getitems__(self, indices):
return asyncio.run(self.getitems(indices))
def collater(self, samples):
    """
    Collate ``samples`` with the collater of the dataset they came from.

    The dataset is identified from the ``"full_id"`` of the first sample.
    When samples carry no ``"full_id"``, all collaters are assumed to be
    the same and the first dataset's collater is used.

    Returns:
        The collated batch, or ``None`` for an empty sample list.
    """
    if len(samples) == 0:
        return None
    head = samples[0]
    if "full_id" not in head:
        # Subclasses may override __getitem__ to not specify full_id
        fallback = list(self.datasets.values())[0]
        return fallback.collater(samples)
    _, key = self._map_index(head["full_id"])
    try:
        batch = self.datasets[key].collater(samples)
    except Exception:
        print(f"Collating failed for key {key}", flush=True)
        raise
    return batch
def num_tokens(self, index: int):
    """Delegate ``num_tokens`` to the dataset that owns global ``index``."""
    local_index, key = self._map_index(index)
    owner = self.datasets[key]
    return owner.num_tokens(local_index)
def size(self, index: int):
    """Delegate ``size`` to the dataset that owns global ``index``."""
    local_index, key = self._map_index(index)
    owner = self.datasets[key]
    return owner.size(local_index)
@property
def can_reuse_epoch_itr_across_epochs(self):
    """Always ``False`` for this dataset."""
    reusable = False
    return reusable
def set_epoch(self, epoch, **unused):
    """
    Record the new epoch on this dataset after notifying the parent class.

    Extra keyword arguments are accepted and ignored.
    """
    super().set_epoch(epoch)
    logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
    # read by ordered_indices() and batch_by_size() when seeding the RNG
    self.epoch = epoch
@property
def supports_prefetch(self):
    """Always ``False`` for this dataset."""
    prefetchable = False
    return prefetchable
@property
def supports_fetch_outside_dataloader(self):
    """``True`` only if every underlying dataset reports the same capability."""
    for dataset in self.datasets.values():
        if not dataset.supports_fetch_outside_dataloader:
            return False
    return True
def batch_by_size(
    self,
    indices,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Batch ``indices``, optionally keeping each batch within one dataset.

    Without ``self.batch_sample`` this defers entirely to the parent
    implementation. With it, indices are grouped by owning dataset, each
    group is batched separately by the parent implementation, and the
    resulting batches are concatenated in dataset order.
    """
    if not self.batch_sample:
        return super().batch_by_size(
            indices, max_tokens, max_sentences, required_batch_size_multiple
        )

    # bucket the global indices by the dataset they map to
    grouped = {key: [] for key in self.datasets}
    for global_index in indices:
        _, owner = self._map_index(global_index)
        grouped[owner].append(global_index)

    batches = []
    for key, members in grouped.items():
        cur_batches = super().batch_by_size(
            np.array(members, dtype=np.int64),
            max_tokens,
            max_sentences,
            required_batch_size_multiple,
        )
        logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
        batches += cur_batches

    # If this dataset is used in a distributed training setup,
    # then shuffle such that the order is seeded by the distributed rank
    # as well
    if self.distributed_rank is not None:
        with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank):
            np.random.shuffle(batches)
    return batches

View File

@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Callable, Dict, List
import numpy as np
from . import FairseqDataset
def uniform_sampler(x):
    """Return one element of ``x`` drawn with ``np.random.choice``."""
    # Sample from uniform distribution
    drawn = np.random.choice(x, 1)
    return drawn.item()
class MultiCorpusSampledDataset(FairseqDataset):
    """
    Holds several FairseqDataset instances and, for every batch, picks one of
    them with ``sampling_func`` and collates only that dataset's samples.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_func: A function for sampling over list of dataset keys.
            The default strategy is to sample uniformly.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Callable[[List], int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        self.sampling_func = (
            uniform_sampler if sampling_func is None else sampling_func
        )

        self.total_num_instances = 0
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)
            self.total_num_instances += len(dataset)

        # filled lazily by ordered_indices()
        self._ordered_indices = None

    def __len__(self):
        """
        Return the summed length of all underlying datasets.
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Return ``np.arange(len(self))`` as the batching order.

        On first call, each underlying dataset's ``ordered_indices()`` is
        computed and cached, so that items are visited in the same order as
        when the underlying dataset is used directly.
        """
        if self._ordered_indices is None:
            cache = OrderedDict()
            for key, dataset in self.datasets.items():
                cache[key] = dataset.ordered_indices()
            self._ordered_indices = cache
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: int, index: int):
        """
        Map a global ``index`` to an index of dataset ``key``.

        The underlying datasets have different lengths, so the index wraps
        around the dataset's length before being looked up in that dataset's
        cached ordering. ``ordered_indices()`` must have been called first.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        wrapped = index % len(self.datasets[key])
        return self._ordered_indices[key][wrapped]

    def __getitem__(self, index: int):
        """
        Return one item per underlying dataset, keyed by dataset key.

        ``index`` ranges over the combined length, so it is mapped into each
        dataset before the item is retrieved.
        """
        pairs = [
            (key, dataset[self._map_index_to_dataset(key, index)])
            for key, dataset in self.datasets.items()
        ]
        return OrderedDict(pairs)

    def collater(self, samples: List[Dict]):
        """
        Build a mini-batch from ``samples``.

        1. Pick a dataset key with ``self.sampling_func``.
        2. Collate that dataset's entry of every sample with the selected
           dataset's own collater.

        Returns ``None`` for an empty sample list.
        """
        if len(samples) == 0:
            return None

        candidate_keys = list(self.datasets.keys())
        selected_key = self.sampling_func(candidate_keys)
        chosen = self.datasets[selected_key]
        return chosen.collater([sample[selected_key] for sample in samples])

    def num_tokens(self, index: int):
        """
        Return the largest ``num_tokens`` found at ``index`` across all
        underlying datasets (used for batching).
        """
        per_dataset = (
            dataset.num_tokens(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )
        return max(per_dataset)

    def size(self, index: int):
        """
        Return the largest ``size`` found at ``index`` across all underlying
        datasets. This value is used when filtering a dataset with
        max-positions.
        """
        per_dataset = (
            dataset.size(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )
        return max(per_dataset)

    @property
    def supports_prefetch(self):
        """True only if every underlying dataset has a truthy ``supports_prefetch``."""
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        """Forward ``indices``, mapped per dataset, to each dataset's ``prefetch``."""
        for key, dataset in self.datasets.items():
            mapped = [self._map_index_to_dataset(key, index) for index in indices]
            dataset.prefetch(mapped)

    @property
    def supports_fetch_outside_dataloader(self):
        """True only if every underlying dataset reports the same capability."""
        return all(
            dataset.supports_fetch_outside_dataloader
            for dataset in self.datasets.values()
        )

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,63 @@
from enum import Enum
from typing import Dict, List, Optional, Sequence
import torch
from fairseq.data import Dictionary
class EncoderLangtok(Enum):
    """
    Which language token is prepended to the beginning of the source
    sentence: the source language's (``src``) or the target language's
    (``tgt``).
    """

    src = "src"
    tgt = "tgt"
class LangTokSpec(Enum):
    """Language-token spec names; values are the strings passed as ``spec``."""

    main = "main"
    mono_dae = "mono_dae"
class LangTokStyle(Enum):
    """Language-token formatting styles; values are the style strings."""

    multilingual = "multilingual"
    mbart = "mbart"
@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """
    Format the language token for ``lang``.

    A ``spec`` ending in ``"dae"`` or ``"mined"`` appends ``_dae`` or
    ``_mined`` to the language name. The result is wrapped as ``[lang]``
    for the mbart style and ``__lang__`` for the multilingual style.

    Raises:
        KeyError: if ``lang_tok_style`` is not a known style value.
    """
    # TOKEN_STYLES can't be defined outside this fn since it needs to be
    # TorchScriptable.
    TOKEN_STYLES: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    if spec.endswith("dae"):
        lang = lang + "_dae"
    elif spec.endswith("mined"):
        lang = lang + "_mined"

    return TOKEN_STYLES[lang_tok_style].format(lang)
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """
    Add language tokens to ``dictionary`` in place.

    One token is added for every (spec, language) combination, specs
    outermost. ``"<mask>"`` is added as well when the style is mbart or
    when ``extra_data`` contains the ``mono_dae`` spec key.
    """
    for spec in langtoks_specs:
        for language in language_list:
            token = get_lang_tok(
                lang=language, lang_tok_style=lang_tok_style, spec=spec
            )
            dictionary.add_symbol(token)

    is_mbart = lang_tok_style == LangTokStyle.mbart.value
    if is_mbart or (
        extra_data is not None and LangTokSpec.mono_dae.value in extra_data
    ):
        dictionary.add_symbol("<mask>")

View File

@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import hashlib
import logging
import time
from bisect import bisect_right
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List
import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils
from fairseq.distributed import utils as distributed_utils
def get_time_gap(s, e):
    """
    Return the elapsed time between two epoch timestamps as a string.

    Args:
        s: start time in seconds since the epoch (e.g. from ``time.time()``).
        e: end time in seconds since the epoch.

    Returns:
        str: ``str()`` of the ``timedelta``, e.g. ``"0:01:30"``.
    """
    # Build the timedelta from the raw difference. Converting both stamps to
    # naive local datetimes first makes the result wrong by the DST shift
    # whenever the interval spans a daylight-saving transition.
    return str(datetime.timedelta(seconds=e - s))
logger = logging.getLogger(__name__)
def default_virtual_size_func(datasets, ratios, max_scale_up=1.5):
    """
    Compute the default virtual size for a set of datasets.

    Without ``ratios`` this is the total number of items. Otherwise each
    dataset's size is scaled relative to the largest dataset according to
    the ratios, and the total is capped at ``max_scale_up`` times the real
    total.

    Returns:
        int: the virtual dataset size.
    """
    sizes = [len(d) for d in datasets]
    real_total = sum(sizes)
    if ratios is None:
        return real_total

    anchor = np.argmax(sizes)
    anchor_ratio = ratios[anchor]
    anchor_size = sizes[anchor]
    # set virtual sizes relative to the largest dataset
    vsize = sum([(r / anchor_ratio) * anchor_size for r in ratios])
    max_size = real_total * max_scale_up
    if vsize < max_size:
        return int(vsize)
    return int(max_size)
class CollateFormat(Enum):
    """Collater output formats: one merged batch, or a dict of batches."""

    single = 1
    ordered_dict = 2
class SampledMultiDataset(FairseqDataset):
"""Samples from multiple sub-datasets according to given sampling ratios.
Args:
datasets (
List[~torch.utils.data.Dataset]
or OrderedDict[str, ~torch.utils.data.Dataset]
): datasets
sampling_ratios (List[float]): list of probability of each dataset to be sampled
(default: None, which corresponds to concatenating all dataset together).
seed (int): RNG seed to use (default: 2).
epoch (int): starting epoch number (default: 1).
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
the collater to output batches of data mixed from all sub-datasets,
and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
of sub-datasets.
Note that not all sub-datasets will present in a single batch in both formats.
virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
shared_collater (bool): whether or not to all sub-datasets have the same collater.
shuffle (bool): whether or not to shuffle data (default: True).
"""
def __init__(
    self,
    datasets,
    sampling_ratios=None,
    seed=2,
    epoch=1,
    eval_key=None,
    collate_format=CollateFormat.single,
    virtual_size=default_virtual_size_func,
    split="",
    shared_collater=False,
    shuffle=True,
):
    """
    Store the sub-datasets and settings, then set up sampling and the
    starting epoch.

    ``datasets`` may be an OrderedDict (its keys become ``self.keys``) or a
    list (keys are the positions). Any other type raises AssertionError.
    When ``eval_key`` is given, the collate format is forced to
    ``CollateFormat.single``.
    """
    super().__init__()
    self.shared_collater = shared_collater
    self.shuffle = shuffle

    if isinstance(datasets, OrderedDict):
        self.keys = list(datasets.keys())
        datasets = list(datasets.values())
    elif isinstance(datasets, List):
        self.keys = list(range(len(datasets)))
    else:
        raise AssertionError()
    self.datasets = datasets
    self.split = split

    self.eval_key = eval_key
    self.collate_format = (
        collate_format if self.eval_key is None else CollateFormat.single
    )

    self.seed = seed
    self._cur_epoch = None

    self.cumulated_sizes = None
    # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset
    # namely, data item i is sampled from the kth sub-dataset self.datasets[k]
    # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k]
    self._cur_indices = None

    self._sizes = None
    self.virtual_size_per_dataset = None
    # caching properties
    self._reset_cached_properties()
    self.setup_sampling(sampling_ratios, virtual_size)
    self.set_epoch(epoch)
def _clean_if_not_none(self, var_list):
for v in var_list:
if v is not None:
del v
def _reset_cached_properties(self):
self._clean_if_not_none([self._sizes, self._cur_indices])
self._sizes = None
self._cur_indices = None
def setup_sampling(self, sample_ratios, virtual_size):
    """
    Set ``self.sample_ratios`` and ``self.virtual_size``.

    Without ratios the dataset behaves as a concatenation and the virtual
    size is the summed real size. With ratios (converted to an ndarray if
    needed), ``virtual_size`` may be an int used as is, a callable invoked
    as ``virtual_size(datasets, ratios)``, or None to use
    ``default_virtual_size_func``.
    """
    sizes = [len(d) for d in self.datasets]
    if sample_ratios is None:
        # default back to concating datasets
        self.sample_ratios = None
        self.virtual_size = sum(sizes)
        return

    if not isinstance(sample_ratios, np.ndarray):
        sample_ratios = np.array(sample_ratios)
    self.sample_ratios = sample_ratios

    if virtual_size is None:
        virtual_size = default_virtual_size_func
    if callable(virtual_size):
        self.virtual_size = virtual_size(self.datasets, self.sample_ratios)
    else:
        self.virtual_size = virtual_size
def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
    """
    Re-run ``setup_sampling`` with new ratios and virtual size.

    Non-None ratios are first passed through ``_sync_sample_ratios``.
    ``epoch`` is accepted but not used here.
    """
    ratios = sampling_ratios
    if ratios is not None:
        ratios = self._sync_sample_ratios(ratios)
    self.setup_sampling(ratios, virtual_size)
def _sync_sample_ratios(self, ratios):
# in case the ratios are not precisely the same across processes
# also to ensure every procresses update the ratios in the same pace
ratios = torch.DoubleTensor(ratios)
if torch.distributed.is_initialized():
if torch.cuda.is_available():
distributed_utils.all_reduce(
ratios.cuda(), group=distributed_utils.get_data_parallel_group()
)
else:
distributed_utils.all_reduce(
ratios, group=distributed_utils.get_data_parallel_group()
)
ret = ratios.cpu()
ret = ret.numpy()
return ret
def random_choice_in_dataset(self, rng, dataset, choice_size):
    """
    Draw ``choice_size`` item indices from ``dataset`` using ``rng``.

    A dataset that provides its own ``random_choice_in_dataset`` is asked
    to do the draw. Otherwise indices are drawn with ``rng.choice``, with
    replacement only when more items are requested than the dataset holds.
    """
    if hasattr(dataset, "random_choice_in_dataset"):
        return dataset.random_choice_in_dataset(rng, choice_size)
    population = len(dataset)
    oversampling = choice_size > population
    return rng.choice(population, choice_size, replace=oversampling)
def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
    """
    Choose which item of which sub-dataset fills each virtual slot.

    Without ratios every item of every dataset is used once, in order.
    With ratios (normalised to sum to 1), each dataset receives
    ``virtual_size * ratio`` slots, truncated to an integer, and the
    slots lost to truncation are handed out at random according to the
    ratios.

    Returns:
        tuple: ``(in_dataset_indices, cumulative_sizes,
        virtual_sizes_per_dataset)`` where the first is the flat array of
        per-dataset item indices, the second the cumulative slot counts and
        the third the slot count per dataset.
    """

    def get_counts(sample_ratios):
        counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64)
        shortfall = virtual_size - counts.sum()
        assert shortfall >= 0
        # due to round-offs, the size might not match the desired sizes
        if shortfall > 0:
            lucky = rng.choice(len(sample_ratios), size=shortfall, p=sample_ratios)
            for ds_pos in lucky:
                counts[ds_pos] += 1
        return counts

    def get_in_dataset_indices(datasets, sizes, sample_ratios):
        counts = get_counts(sample_ratios)
        # uniformally sample desired counts for each dataset
        # if the desired counts are large, sample with replacement:
        return [
            self.random_choice_in_dataset(rng, d, c)
            for c, d in zip(counts, datasets)
        ]

    sizes = [len(d) for d in datasets]
    if sample_ratios is not None:
        ratios = sample_ratios / sample_ratios.sum()
        in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios)
        virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices]
    else:
        # default back to concating datasets
        in_dataset_indices = [list(range(s)) for s in sizes]
        virtual_sizes_per_dataset = sizes

    virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64)
    cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
    assert sum(virtual_sizes_per_dataset) == virtual_size
    assert cumulative_sizes[-1] == virtual_size

    real_total = sum(sizes)
    if virtual_size < real_total:
        logger.warning(
            f"virtual data size ({virtual_size}) is less than real data size ({real_total})."
            " If virtual size << real data size, there could be data coverage issue."
        )
    in_dataset_indices = np.hstack(in_dataset_indices)
    return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset
def _get_dataset_and_index(self, index):
i = bisect_right(self.cumulated_sizes, index)
return i, self._cur_indices[index]
def __getitem__(self, index):
# self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]]
# where k satisfies self.cumulated_sizes[k - 1] <= k < self.cumulated_sizes[k]
ds_idx, ds_sample_idx = self._get_dataset_and_index(index)
ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx])
return ret
def num_tokens(self, index):
    """Return the largest entry of ``self.sizes[index]``."""
    row = self.sizes[index]
    return row.max()
def num_tokens_vec(self, indices):
    """Vectorised ``num_tokens``: per-index maximum of ``self.sizes``."""
    picked = self.sizes[np.array(indices)]
    # max across all dimensions but first one
    trailing_axes = tuple(range(1, len(picked.shape)))
    return np.amax(picked, axis=trailing_axes)
def size(self, index):
    """Return the ``self.sizes`` entry for ``index``."""
    entry = self.sizes[index]
    return entry
def __len__(self):
return self.virtual_size
def collater(self, samples, **extra_args):
    """Merge a list of samples to form a mini-batch.

    Args:
        samples: list of ``(sub-dataset position, sample)`` pairs.
        **extra_args: may carry ``pad_to_length``, a mapping with
            ``"source"``/``"target"`` minimum lengths. It is updated in
            place and forwarded to the sub-dataset collaters in the merged
            format.

    Returns:
        ``None`` for an empty list. For the ``ordered_dict`` collate format,
        an OrderedDict mapping each key that has samples to that
        sub-dataset's batch. With ``shared_collater``, the first
        sub-dataset's batch over all samples. Otherwise a single dict that
        concatenates the per-dataset batches, sorted by descending source
        length.
    """
    if len(samples) == 0:
        return None

    # ``collate_format`` is normally a CollateFormat member. Comparing it
    # with the plain string "ordered_dict" never matches an Enum member, so
    # compare by name instead; this accepts both the member and the string.
    fmt = self.collate_format
    if getattr(fmt, "name", fmt) == "ordered_dict":
        collect_samples = [[] for _ in range(len(self.datasets))]
        for i, sample in samples:
            collect_samples[i].append(sample)
        return OrderedDict(
            [
                (key, dataset.collater(group))
                for key, dataset, group in zip(
                    self.keys, self.datasets, collect_samples
                )
                if len(group) > 0
            ]
        )

    if self.shared_collater:
        return self.datasets[0].collater([s for _, s in samples])

    samples_dict = defaultdict(list)
    pad_to_length = (
        defaultdict(int)
        if "pad_to_length" not in extra_args
        else extra_args["pad_to_length"]
    )
    # every sub-batch is padded to the longest source/target seen overall
    # so that the sub-batches can be concatenated afterwards
    for ds_idx, s in samples:
        pad_to_length["source"] = max(
            pad_to_length["source"], s["source"].size(0)
        )
        if s["target"] is not None:
            pad_to_length["target"] = max(
                pad_to_length["target"], s["target"].size(0)
            )
        samples_dict[ds_idx].append(s)
    batches = [
        self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
        for i in range(len(self.datasets))
        if len(samples_dict[i]) > 0
    ]

    def straight_data(tensors):
        # concatenate the per-dataset tensors along the batch dimension
        return torch.cat(tensors, dim=0)

    src_lengths = straight_data([b["net_input"]["src_lengths"] for b in batches])
    src_lengths, sort_order = src_lengths.sort(descending=True)

    def straight_order(tensors):
        # concatenate, then reorder rows to match the sorted source lengths
        return straight_data(tensors).index_select(0, sort_order)

    batch = {
        "id": straight_order([b["id"] for b in batches]),
        "nsentences": sum(b["nsentences"] for b in batches),
        "ntokens": sum(b["ntokens"] for b in batches),
        "net_input": {
            "src_tokens": straight_order(
                [b["net_input"]["src_tokens"] for b in batches]
            ),
            "src_lengths": src_lengths,
        },
        "target": straight_order([b["target"] for b in batches])
        if batches[0]["target"] is not None
        else None,
    }
    # optional fields are copied only when the first sub-batch has them
    if "prev_output_tokens" in batches[0]["net_input"]:
        batch["net_input"]["prev_output_tokens"] = straight_order(
            [b["net_input"]["prev_output_tokens"] for b in batches]
        )
    if "src_lang_id" in batches[0]["net_input"]:
        batch["net_input"]["src_lang_id"] = straight_order(
            [b["net_input"]["src_lang_id"] for b in batches]
        )
    if "tgt_lang_id" in batches[0]:
        batch["tgt_lang_id"] = straight_order(
            [b["tgt_lang_id"] for b in batches]
        )
    return batch
@property
def sizes(self):
    """
    Sizes of all virtual items, stacked in virtual-index order.

    Computed on first access from each sub-dataset's ``sizes`` and the
    currently sampled indices, then cached in ``self._sizes``.
    """
    if self._sizes is not None:
        return self._sizes
    start_time = time.time()
    sub_dataset_sizes = []
    for i, dataset in enumerate(self.datasets):
        # slot range of sub-dataset i inside the flat index array
        begin = 0 if i == 0 else self.cumulated_sizes[i - 1]
        end = self.cumulated_sizes[i]
        sub_dataset_sizes.append(dataset.sizes[self._cur_indices[begin:end]])
    self._sizes = np.vstack(sub_dataset_sizes)
    logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
    return self._sizes
def ordered_indices(self):
    """
    Return the virtual indices ordered by length.

    Indices start either shuffled (``self.shuffle``) or in natural order
    and are then stably sorted by target length (when sizes has a second
    column) and finally by source length.

    Returns:
        np.ndarray: the sorted indices.
    """
    if self.shuffle:
        indices = np.random.permutation(len(self))
    else:
        indices = np.arange(len(self))

    sizes = self.sizes
    # Only index the second axis when there is one. The rank check has to
    # be "> 1": a 1-D sizes array passes "> 0" and then raises IndexError
    # on sizes.shape[1].
    paired = len(sizes.shape) > 1 and sizes.shape[1] > 1
    tgt_sizes = sizes[:, 1] if paired else None
    src_sizes = sizes[:, 0] if paired else sizes

    # sort by target length, then source length
    if tgt_sizes is not None:
        indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
    sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
    return sort_indices
def prefetch(self, indices):
    """
    Forward a prefetch request to the sub-datasets.

    Each virtual index is resolved to its sub-dataset and item index; every
    sub-dataset's ``prefetch`` is then called once with its (possibly
    empty) list of item indices.
    """
    per_dataset = [[] for _ in range(len(self.datasets))]
    for virtual_index in indices:
        ds_idx, ds_sample_idx = self._get_dataset_and_index(virtual_index)
        per_dataset[ds_idx].append(ds_sample_idx)
    for ds_idx, wanted in enumerate(per_dataset):
        self.datasets[ds_idx].prefetch(wanted)
@property
def can_reuse_epoch_itr_across_epochs(self):
    """Always ``False`` for this dataset."""
    reusable = False
    return reusable
def set_epoch(self, epoch):
    """
    Move to ``epoch`` and rebuild the virtual dataset.

    The parent class is always notified. If ``epoch`` is already current
    nothing else happens; otherwise the epoch is propagated to every
    sub-dataset that has ``set_epoch`` and the virtual indices are
    re-established.
    """
    super().set_epoch(epoch)
    if epoch == self._cur_epoch:
        # re-enter so return
        return
    for dataset in self.datasets:
        if not hasattr(dataset, "set_epoch"):
            continue
        dataset.set_epoch(epoch)
    self._cur_epoch = epoch
    self._establish_virtual_datasets()
def _establish_virtual_datasets(self):
    """
    (Re)sample the virtual indices for the current epoch and log a summary.

    Nothing is done when no ratios are set and indices already exist,
    because a plain concatenation never changes. Otherwise cached
    properties are reset and new indices are drawn from an RNG seeded by
    the class name, ``self.seed`` and ``self._cur_epoch``.
    """
    if self.sample_ratios is None and self._cur_indices is not None:
        # not a samping dataset, no need to resample if indices are already established
        return
    self._reset_cached_properties()

    start_time = time.time()
    # Generate a weighted sample of indices as a function of the
    # random seed and the current epoch.
    class_digest = hashlib.sha1(
        str(self.__class__.__name__).encode("utf-8")
    ).hexdigest()
    rng = np.random.RandomState(
        [
            int(class_digest, 16) % (2**32),
            self.seed % (2**32),  # global seed
            self._cur_epoch,  # epoch index,
        ]
    )
    self._clean_if_not_none(
        [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes]
    )
    self._sizes = None

    (
        self._cur_indices,
        self.cumulated_sizes,
        self.virtual_size_per_dataset,
    ) = self.get_virtual_indices(
        rng, self.datasets, self.sample_ratios, self.virtual_size
    )

    raw_sizes = [len(d) for d in self.datasets]
    sampled_sizes = self.virtual_size_per_dataset
    logger.info(
        f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
        f"raw total size: {sum(raw_sizes)}"
    )
    logger.info(
        f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
        f"resampled total size: {sum(sampled_sizes)}"
    )
    if self.sample_ratios is None:
        logger.info(f"[{self.split}] A concat dataset")
    else:
        logger.info(
            f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
        )
    logger.info(
        f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
    )
def filter_indices_by_size(self, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
    than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    sizes = self.sizes
    # Only index the second axis when there is one. The rank check has to
    # be "> 1": a 1-D sizes array passes "> 0" and then raises IndexError
    # on sizes.shape[1].
    paired = len(sizes.shape) > 1 and sizes.shape[1] > 1
    tgt_sizes = sizes[:, 1] if paired else None
    src_sizes = sizes[:, 0] if paired else sizes
    return data_utils.filter_paired_dataset_indices_by_size(
        src_sizes, tgt_sizes, indices, max_sizes
    )

View File

@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import logging
import math
import numpy as np
from fairseq.data import SampledMultiDataset
from .sampled_multi_dataset import CollateFormat, default_virtual_size_func
logger = logging.getLogger(__name__)
class SampledMultiEpochDataset(SampledMultiDataset):
"""Samples from multiple sub-datasets according to sampling ratios
using virtual epoch sizes to speed up dataloading.
Args:
datasets (
List[~torch.utils.data.Dataset]
or OrderedDict[str, ~torch.utils.data.Dataset]
): datasets
sampling_ratios (List[float]): list of probability of each dataset to be sampled
(default: None, which corresponds to concating all dataset together).
seed (int): RNG seed to use (default: 2).
epoch (int): starting epoch number (default: 1).
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
the collater to output batches of data mixed from all sub-datasets,
and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
of sub-datasets.
Note that not all sub-datasets will present in a single batch in both formats.
virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by
this virtual epoch size one by one to speed up data loading, e.g. indicing and filtering
can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded.
shared_collater (bool): whether or not to all sub-datasets have the same collater.
shard_epoch (int): the real epoch number for shard selection.
shuffle (bool): whether or not to shuffle data (default: True).
"""
def __init__(
    self,
    datasets,
    sampling_ratios=None,
    seed=2,
    epoch=1,
    eval_key=None,
    collate_format=CollateFormat.single,
    virtual_size=default_virtual_size_func,
    split="",
    virtual_epoch_size=None,
    shared_collater=False,
    shard_epoch=1,
    shuffle=True,
):
    """
    Initialise the virtual-epoch state, then the parent dataset.

    The epoch-related attributes are set before calling the parent
    constructor because that constructor calls ``set_epoch``, which reads
    them. A ``shard_epoch`` of None is treated as 1.
    """
    self.virtual_epoch_size = virtual_epoch_size
    self._current_epoch_start_index = None
    self._random_global_indices = None
    self.shard_epoch = 1 if shard_epoch is None else shard_epoch
    self.load_next_shard = None
    self._epoch_sizes = None
    super().__init__(
        datasets=datasets,
        sampling_ratios=sampling_ratios,
        seed=seed,
        epoch=epoch,
        eval_key=eval_key,
        collate_format=collate_format,
        virtual_size=virtual_size,
        split=split,
        shared_collater=shared_collater,
        shuffle=shuffle,
    )
def _setup(self, epoch):
    """
    Finalise the virtual epoch size and derive the epoch bookkeeping.

    An unset virtual epoch size defaults to the whole virtual dataset, and
    one larger than the virtual dataset is clamped to it with a warning.
    Also sets ``num_virtual_epochs`` and the start index for ``epoch``.
    """
    if self.virtual_epoch_size is None:
        self.virtual_epoch_size = self.virtual_size
    if self.virtual_epoch_size > self.virtual_size:
        logger.warning(
            f"virtual epoch size {self.virtual_epoch_size} "
            f"is greater than virtual dataset size {self.virtual_size}"
        )
        self.virtual_epoch_size = self.virtual_size
    self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
    self._current_epoch_start_index = self._get_epoch_start_index(epoch)
    logger.info(
        f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
    )
def _map_epoch_index_to_global(self, index):
index = self._current_epoch_start_index + index
# add randomness
return self._random_global_indices[index]
@property
def sizes(self):
    """
    Sizes of the items in the current virtual epoch only.

    Computed on first access from the parent's full size table, cached in
    ``self._epoch_sizes``, after which the full table is released.
    """
    if self._epoch_sizes is not None:
        return self._epoch_sizes
    _sizes = super().sizes
    begin = self._current_epoch_start_index
    end = self._current_epoch_start_index + len(self)
    epoch_indices = self._random_global_indices[begin:end]
    self._epoch_sizes = _sizes[epoch_indices]
    # del super()._sizes to save memory
    del self._sizes
    self._sizes = None
    return self._epoch_sizes
def _get_dataset_and_index(self, index):
    """
    Resolve an in-epoch ``index`` by first mapping it to a global virtual
    index and then using the parent's lookup.
    """
    global_index = self._map_epoch_index_to_global(index)
    return super()._get_dataset_and_index(global_index)
def __len__(self):
return (
self.virtual_epoch_size
if self._current_epoch_start_index + self.virtual_epoch_size
< self.virtual_size
else self.virtual_size - self._current_epoch_start_index
)
def set_epoch(self, epoch):
    """
    Advance to ``epoch``.

    On the first call the virtual-epoch bookkeeping is set up before moving
    to the epoch. Later calls move to the epoch unless it is already the
    current one.
    """
    if self._current_epoch_start_index is None:
        # initializing epoch indices of a virtual dataset
        self._setup(epoch)
        self._next_virtual_epoch(epoch)
        return
    # working on already initialized epoch indices
    if epoch == self._cur_epoch:
        # re-enter so return
        return
    self._next_virtual_epoch(epoch)
def _get_epoch_start_index(self, epoch):
assert epoch >= 1 # fairseq is using 1-based epoch everywhere
return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size
def _next_global_indices(self, epoch):
rng = np.random.RandomState(
[
int(
hashlib.sha1(
str(self.__class__.__name__).encode("utf-8")
).hexdigest(),
16,
)
% (2**32),
self.seed % (2**32), # global seed
epoch, # epoch index,
]
)
del self._random_global_indices
self._random_global_indices = rng.choice(
self.virtual_size, self.virtual_size, replace=False
)
if self.load_next_shard is None:
self.load_next_shard = False
else:
# increase shard epoch for next loading
self.shard_epoch += 1
self.load_next_shard = True
logger.info(
"to load next epoch/shard in next load_dataset: "
f"epoch={epoch}/shard_epoch={self.shard_epoch}"
)
def _next_virtual_epoch(self, epoch):
    """
    Position the dataset at the virtual epoch belonging to ``epoch``.

    When the epoch starts at offset 0, or no global permutation exists yet,
    the parent's ``set_epoch`` rebuilds the virtual dataset and a new
    global permutation is drawn. Otherwise only the current epoch number
    is updated. In both cases the cached epoch sizes are dropped and the
    start index is moved.
    """
    index = self._get_epoch_start_index(epoch)
    needs_rebuild = index == 0 or self._random_global_indices is None
    if needs_rebuild:
        # need to start from the beginning,
        # so call super().set_epoch(epoch) to establish the global virtual indices
        logger.info(
            "establishing a new set of global virtual indices for "
            f"epoch={epoch}/shard_epoch={self.shard_epoch}"
        )
        super().set_epoch(epoch)
        self._next_global_indices(epoch)
    else:
        self._cur_epoch = epoch

    # reset cache sizes and ordered_indices for the epoch after moving to a new epoch
    self._clean_if_not_none([self._epoch_sizes])
    self._epoch_sizes = None
    self._current_epoch_start_index = index

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List
logger = logging.getLogger(__name__)
def uniform(dataset_sizes: List[int]):
    """Return a weight of 1.0 for every dataset, regardless of its size."""
    count = len(dataset_sizes)
    return [1.0] * count
def temperature_sampling(dataset_sizes, temp):
    """
    Return one weight per dataset: its share of the total size raised to
    the power ``1 / temp``.
    """
    total_size = sum(dataset_sizes)
    exponent = 1.0 / temp
    weights = []
    for size in dataset_sizes:
        weights.append((size / total_size) ** exponent)
    return weights
def make_temperature_sampling(temp=1.0):
    """
    Return a function of ``dataset_sizes`` that applies
    ``temperature_sampling`` with the given ``temp``.
    """

    def sampling_func(dataset_sizes):
        weights = temperature_sampling(dataset_sizes, temp)
        return weights

    return sampling_func
def make_ratio_sampling(ratios):
    """
    Return a function that ignores the dataset sizes it is given and
    always returns the fixed ``ratios``.
    """

    def sampling_func(dataset_sizes):
        fixed = ratios
        return fixed

    return sampling_func
class SamplingMethod:
    """Selects the dataset sampling function from command-line arguments."""

    @staticmethod
    def add_arguments(parser):
        """Register ``--sampling-method`` and ``--sampling-temperature``."""
        method_choices = [
            "uniform",
            "temperature",
            "concat",
            "RoundRobin",
        ]
        parser.add_argument(
            "--sampling-method",
            choices=method_choices,
            type=str,
            default="concat",
            help="The method to sample data per language pairs",
        )
        parser.add_argument(
            "--sampling-temperature",
            default=1.5,
            type=float,
            help="only work with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        """Create a ``SamplingMethod`` for ``args`` and ``task``."""
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        self.args = args
        self.task = task

    def is_adaptive(self):
        """Return False; this sampler is not adaptive."""
        return False

    def sampling_method_selector(self):
        """
        Return the sampling function chosen by ``args.sampling_method``.

        ``"uniform"`` gives ``uniform``; ``"temperature"`` (or an adaptive
        sampler) gives a temperature sampler built from
        ``args.sampling_temperature``; anything else returns None, which
        means concatenating all datasets.
        """
        method = self.args.sampling_method
        logger.info(f"selected sampler: {method}")
        if method == "uniform":
            return uniform
        if method == "temperature" or self.is_adaptive():
            return make_temperature_sampling(float(self.args.sampling_temperature))
        # default to concating all data set together
        return None

View File

@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import torch
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
def _flatten(dico, prefix=None):
"""Flatten a nested dictionary."""
new_dico = OrderedDict()
if isinstance(dico, dict):
prefix = prefix + "." if prefix is not None else ""
for k, v in dico.items():
if v is None:
continue
new_dico.update(_flatten(v, prefix + k))
elif isinstance(dico, list):
for i, v in enumerate(dico):
new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
else:
new_dico = OrderedDict({prefix: dico})
return new_dico
def _unflatten(dico):
"""Unflatten a flattened dictionary into a nested dictionary."""
new_dico = OrderedDict()
for full_k, v in dico.items():
full_k = full_k.split(".")
node = new_dico
for k in full_k[:-1]:
if k.startswith("[") and k.endswith("]"):
k = int(k[1:-1])
if k not in node:
node[k] = OrderedDict()
node = node[k]
node[full_k[-1]] = v
return new_dico
class NestedDictionaryDataset(FairseqDataset):
    """Dataset whose items are dictionaries assembled from sub-datasets.

    The (possibly nested) definition is flattened on construction; items are
    returned with the flattened keys and batches are re-nested by
    :meth:`collater`.
    """

    def __init__(self, defn, sizes=None):
        """
        Args:
            defn: nested dictionary whose leaves are datasets.
            sizes: a sizes container, or a list/tuple of them; a single
                container is wrapped into a one-element list.

        Raises:
            ValueError: if a leaf of *defn* is not a dataset.
        """
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset)):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            # Truthiness is used on purpose: an empty leading dataset is
            # skipped so the reference length comes from a non-empty one.
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                # Sub-dataset has no collater of its own: use torch's default.
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        # Materialize as a tuple (as documented) instead of handing out a
        # one-shot generator that cannot be compared, indexed or re-read.
        return tuple(s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        # getattr keeps this consistent with prefetch(): plain torch datasets
        # (accepted by the constructor) do not define ``supports_prefetch``.
        return any(
            getattr(ds, "supports_prefetch", False) for ds in self.defn.values()
        )

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)

View File

@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from fairseq.data import data_utils
class WordNoising(object):
    """Generate a noisy version of a sentence, without changing words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
        """
        Args:
            dictionary: indexable vocabulary; ``dictionary[i]`` is the token
                string for index ``i``.
            bpe_cont_marker: suffix marking a token that continues into the
                next one. Takes precedence over *bpe_end_marker* when truthy.
            bpe_end_marker: suffix marking the last token of a word.
        """
        self.dictionary = dictionary
        self.bpe_end = None

        if bpe_cont_marker:
            # A token ends a word unless it carries the continuation marker.
            flags = []
            for i in range(len(self.dictionary)):
                flags.append(not self.dictionary[i].endswith(bpe_cont_marker))
            self.bpe_end = np.array(flags)
        elif bpe_end_marker:
            # A token ends a word only if it carries the end marker.
            flags = []
            for i in range(len(self.dictionary)):
                flags.append(self.dictionary[i].endswith(bpe_end_marker))
            self.bpe_end = np.array(flags)

        # Without any marker every token is treated as its own word.
        if self.bpe_end is None:
            self.get_word_idx = self._get_token_idx
        else:
            self.get_word_idx = self._get_bpe_word_idx

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        """
        Given a list of BPE tokens, for every index in the tokens list,
        return the index of the word grouping that it belongs to.
        For example, for input x corresponding to ["how", "are", "y@@", "ou"],
        return [[0], [1], [2], [2]].
        """
        # x: (T x B)
        bpe_end = self.bpe_end[x]

        is_single_token = x.size(0) == 1 and x.size(1) == 1
        if is_single_token:
            # Special case when we only have one word in x. If x = [[N]],
            # bpe_end is a scalar (bool) instead of a 2-dim array of bools,
            # which makes the sum operation below fail.
            return np.array([[0]])

        # do a reduce front sum to generate word ids
        ends_from_here = bpe_end[::-1].cumsum(0)[::-1]
        return ends_from_here.max(0)[None, :] - ends_from_here

    def _get_token_idx(self, x):
        """
        This is to extend noising functions to be able to apply to non-bpe
        tokens, e.g. word or characters.
        """
        per_sentence = [range(len(row)) for row in torch.t(x)]
        return np.transpose(np.array(per_sentence))
class WordDropout(WordNoising):
    """Randomly drop whole input words.

    When ``blank_idx`` is left as ``None`` the dropped words are removed from
    the sentence; otherwise each dropped token is replaced by ``blank_idx``.
    """

    def __init__(
        self,
        dictionary,
        default_dropout_prob=0.1,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_dropout_prob = default_dropout_prob

    def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
        """Apply word dropout to a batch.

        Args:
            x: token tensor of shape (T x B).
            lengths: tensor of size B with the sentence lengths.
            dropout_prob: per-word drop probability; falls back to the
                instance default when ``None``. ``0`` returns the input
                unchanged.
            blank_idx: replacement index for dropped tokens, or ``None`` to
                remove them.

        Returns:
            Tuple ``(modified_x, modified_lengths)``; ``modified_x`` is
            padded with ``dictionary.pad()``.
        """
        if dropout_prob is None:
            dropout_prob = self.default_dropout_prob
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words
        word_idx = self.get_word_idx(x)
        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # The keep mask is drawn per sentence because it must only cover
            # non-pad tokens, which requires the sentence length.
            num_words = max(word_idx[:, i]) + 1

            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            # When the sentence ends in EOS, its word index is excluded from
            # the random mask and EOS is always kept.
            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
            if has_eos:  # has eos?
                keep = np.random.rand(num_words - 1) >= dropout_prob
                keep = np.append(keep, [True])  # keep EOS symbol
            else:
                keep = np.random.rand(num_words) >= dropout_prob

            words = x[: lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop (or blank) tokens according to the mask of their word
            new_s = []
            for j, w in enumerate(words):
                token = w if keep[word_idx[j, i]] else blank_idx
                if token is not None:
                    new_s.append(token)

            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) <= 1:
                # insert at beginning in case the only token left is EOS
                # EOS should be at end of list.
                new_s.insert(0, words[np.random.randint(0, len(words))])
            assert len(new_s) >= 1 and (
                not has_eos  # Either don't have EOS at end or last token is EOS
                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
            ), "New sentence is invalid."

            sentences.append(new_s)
            modified_lengths.append(len(new_s))

        # re-construct input
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(), modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            target_column = modified_x[: modified_lengths[i], i]
            target_column.copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths
class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(
        self,
        dictionary,
        default_max_shuffle_distance=3,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        """
        Args:
            dictionary: vocabulary passed on to :class:`WordNoising`.
            default_max_shuffle_distance: distance used by :meth:`noising`
                when no explicit ``max_shuffle_distance`` is given.
            bpe_cont_marker: see :class:`WordNoising`.
            bpe_end_marker: see :class:`WordNoising`.
        """
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        # Honor the constructor argument; it used to be ignored in favor of
        # a hard-coded 3 (which is still the default value).
        self.default_max_shuffle_distance = default_max_shuffle_distance

    def noising(self, x, lengths, max_shuffle_distance=None):
        """Shuffle the words of every sentence in the batch.

        Args:
            x: token tensor of shape (T x B).
            lengths: tensor of size B with the sentence lengths.
            max_shuffle_distance: upper bound of the uniform noise added to
                each word position; ``None`` uses the instance default and
                ``0`` returns the input unchanged.

        Returns:
            Tuple ``(x2, lengths)`` where ``x2`` is a shuffled clone of ``x``.
        """
        if max_shuffle_distance is None:
            max_shuffle_distance = self.default_max_shuffle_distance
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance <= 1 would return the same sequence
        assert max_shuffle_distance > 1

        # define noise word scores
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self.get_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            length_no_eos = lengths[i]
            if x[lengths[i] - 1, i] == self.dictionary.eos():
                length_no_eos = lengths[i] - 1
            # generate a random permutation: all tokens of a word share the
            # noise value drawn for that word's index
            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
            # ensure no reordering inside a word
            scores += 1e-6 * np.arange(length_no_eos.item())
            permutation = scores.argsort()
            # shuffle words
            x2[:length_no_eos, i].copy_(
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths
class UnsupervisedMTNoising(WordNoising):
    """
    Implements the default configuration for noising in UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT)
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        """
        Args:
            dictionary: vocabulary shared by all noising steps.
            max_word_shuffle_distance: distance passed to the shuffle step.
            word_dropout_prob: probability passed to the dropout step.
            word_blanking_prob: probability passed to the blanking step.
            bpe_cont_marker: see :class:`WordNoising`.
            bpe_end_marker: see :class:`WordNoising`.
        """
        # Forward the markers so this instance's own word-index state agrees
        # with the sub-noisers below (they used to be dropped here, leaving
        # the base class configured with its defaults).
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        self.word_dropout = WordDropout(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )
        self.word_shuffle = WordShuffle(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )

    def noising(self, x, lengths):
        """Apply shuffle, dropout and blanking, in that order.

        Returns:
            Only the noised tokens; the updated lengths are not returned.
        """
        # 1. Word Shuffle
        noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2. Word Dropout
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_dropout_prob,
        )
        # 3. Word Blanking: dropout that substitutes UNK instead of removing
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )

        return noisy_src_tokens
class NoisingDataset(torch.utils.data.Dataset):
    """Dataset wrapper that returns a noised copy of each source sample."""

    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noiser=None,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        """
        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
        samples based on the supplied noising configuration.

        Args:
            src_dataset (~torch.utils.data.Dataset): dataset to wrap. Its
                items should NOT be padded, since the length handed to the
                noiser is taken from ``len(item)``.
            src_dict (~fairseq.data.Dictionary): source dictionary
            seed (int): seed to use when generating random noise
            noiser (WordNoising): a pre-initialized :class:`WordNoising`
                instance. If this is None, a new instance will be created using
                *noising_class* and *kwargs*.
            noising_class (class, optional): class to use to initialize a
                default :class:`WordNoising` instance.
            kwargs (dict, optional): arguments to initialize the default
                :class:`WordNoising` instance given by *noiser*.
        """
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        if noiser is None:
            noiser = noising_class(dictionary=src_dict, **kwargs)
        self.noiser = noiser
        self.sizes = src_dataset.sizes

    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
        create a noising dataset batch.
        """
        tokens = self.src_dataset[index]
        src_lengths = torch.LongTensor([len(tokens)])

        # The noiser expects (sequence length, batch size); build a batch of
        # one and transpose it: (1, sequence length) -> (sequence length, 1)
        column = torch.t(tokens.unsqueeze(0))

        # Seed per sample so the noise is reproducible for a given index.
        with data_utils.numpy_seed(self.seed + index):
            noisy_src_tokens = self.noiser.noising(column, src_lengths)

        # Transpose back to expected src_tokens format
        # (sequence length, 1) -> (1, sequence length)
        return torch.t(noisy_src_tokens)[0]

    def __len__(self):
        """
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        return self.src_dataset.supports_prefetch

    def prefetch(self, indices):
        if not self.src_dataset.supports_prefetch:
            return
        self.src_dataset.prefetch(indices)

View File

@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import FairseqDataset
class NumSamplesDataset(FairseqDataset):
    """Dataset whose every item is ``1`` so that collating counts samples."""

    def __len__(self):
        # Reports zero length regardless of how it is indexed.
        return 0

    def __getitem__(self, index):
        # Each sample contributes exactly one to the collated count.
        return 1

    def collater(self, samples):
        """Return the sum of *samples*, i.e. the number of samples in the batch."""
        total = sum(samples)
        return total

View File

@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class NumelDataset(BaseWrapperDataset):
    """Wrapper returning the number of elements of each wrapped item."""

    def __init__(self, dataset, reduce=False):
        """
        Args:
            dataset: dataset to wrap.
            reduce: if true, :meth:`collater` sums the counts into a single
                number; otherwise it returns them as a tensor.
        """
        super().__init__(dataset)
        self.reduce = reduce

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        # Tensors are counted with torch, anything else with numpy.
        return torch.numel(item) if torch.is_tensor(item) else np.size(item)

    def collater(self, samples):
        return sum(samples) if self.reduce else torch.tensor(samples)

View File

@@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class OffsetTokensDataset(BaseWrapperDataset):
    """Wrapper that adds a constant offset to every wrapped item."""

    def __init__(self, dataset, offset):
        """
        Args:
            dataset: dataset to wrap.
            offset: value added to each item on access.
        """
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return item + self.offset

View File

@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import data_utils
from . import BaseWrapperDataset
class PadDataset(BaseWrapperDataset):
    """Wrapper whose collater pads samples via ``data_utils.collate_tokens``."""

    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        """
        Args:
            dataset: dataset to wrap.
            pad_idx: padding index forwarded to ``collate_tokens``.
            left_pad: forwarded as ``left_pad``.
            pad_length: forwarded as ``pad_to_length`` (default: None).
        """
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad
        self.pad_length = pad_length

    def collater(self, samples):
        batch = data_utils.collate_tokens(
            samples,
            self.pad_idx,
            left_pad=self.left_pad,
            pad_to_length=self.pad_length,
        )
        return batch
class LeftPadDataset(PadDataset):
    """``PadDataset`` preset with ``left_pad=True``."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset=dataset, pad_idx=pad_idx, left_pad=True)
class RightPadDataset(PadDataset):
    """``PadDataset`` preset with ``left_pad=False``."""

    def __init__(self, dataset, pad_idx):
        super().__init__(dataset=dataset, pad_idx=pad_idx, left_pad=False)

View File

@@ -0,0 +1,197 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import json
import subprocess
import tempfile
from typing import Hashable
try:
import pyarrow.plasma as plasma
PYARROW_AVAILABLE = True
except ImportError:
plasma = None
PYARROW_AVAILABLE = False
class PlasmaArray:
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.
    """

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        """The plasma module, or ``None`` when disabled or unavailable."""
        if self._plasma is None and not self.disable:
            self._plasma = plasma
        return self._plasma

    def start_server(self):
        """Spawn a ``plasma_store`` process for this array (at most once).

        Does nothing when plasma is not in use or a server already runs.
        """
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen(
            # allocate slightly more than the array needs
            ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path]
        )

    @property
    def client(self):
        """Lazily created plasma client connected to ``self.path``."""
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Called on pickle save"""
        if self.plasma is None:
            # Plasma not in use: pickle the object (array included) as-is.
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        # The array travels through shared memory, not through the pickle.
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Called on pickle load"""
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        # getattr: __del__ also runs on instances whose __init__ never got as
        # far as creating ``_server`` (e.g. it raised on ``array.nbytes``).
        server = getattr(self, "_server", None)
        if server is not None:
            server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None
DEFAULT_PLASMA_PATH = "/tmp/plasma"


class PlasmaView:
    """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization,
    PlasmaView writes to shared memory on instantiation."""

    def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None):
        """
        Args:
            array: numpy array to store. This can be read with ``PlasmaView().array``
            split_path: the path whence the data was read, used for hashing
            hash_data: other metadata about the array that can be used to create a unique key.
                as of writing, the 3 callers in ``TokenBlockDataset`` use::

                    hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2)
            plasma_path: socket path of the plasma store; defaults to
                ``DEFAULT_PLASMA_PATH``.
        """
        assert PYARROW_AVAILABLE
        assert split_path is not None
        if plasma_path is None:
            plasma_path = DEFAULT_PLASMA_PATH

        self.path = plasma_path
        self.split_path = split_path
        self._client = None  # Initialize lazily for pickle. plasma clients should not be deep copied or serialized.
        self._n = None

        self.object_id = self.get_object_id(self.split_path, hash_data)
        try:
            self.client.put(array, object_id=self.object_id)
        except plasma.PlasmaObjectExists:
            # Same key already stored: reuse the existing object.
            pass

    @property
    def client(self):
        """Lazily created plasma client connected to ``self.path``."""
        if self._client is None:
            self._client = plasma.connect(self.path, num_retries=200)
        return self._client

    @property
    def array(self):
        """Fetch a read only view of an np.array, stored in plasma."""
        return self.client.get(self.object_id)

    @staticmethod
    def get_object_id(split_path: str, hash_data: Hashable):
        """Returns plasma.ObjectID from hashing split_path and object_num."""
        # 20-byte digest, named so that the builtin ``hash`` is not shadowed.
        hasher = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20)
        hasher.update(json.dumps(hash_data).encode("utf-8"))
        return plasma.ObjectID(hasher.digest())

    def __getstate__(self):
        """Called on pickle save"""
        self.disconnect()
        state = self.__dict__.copy()
        assert state["_client"] is None
        assert "object_id" in state
        return state

    def __setstate__(self, state):
        """Called on pickle load"""
        self.__dict__.update(state)

    def __del__(self):
        self.disconnect()

    def disconnect(self):
        """Close the plasma client connection, if one was opened."""
        # getattr: this also runs (via __del__) on instances whose __init__
        # failed before ``_client`` was assigned, e.g. when pyarrow is missing.
        client = getattr(self, "_client", None)
        if client is not None:
            client.disconnect()
            self._client = None

    def __len__(self):
        """Save reads by caching len"""
        if self._n is None:
            self._n = len(self.array)
        return self._n
GB100 = (1024**3) * 100


class PlasmaStore:
    """Owns a ``plasma_store`` server process for the object's lifetime."""

    def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100):
        """
        Args:
            path: socket path the store listens on.
            nbytes: amount of memory handed to ``plasma_store -m``.
        """
        self.server = self.start(path, nbytes)

    def __del__(self):
        # getattr: ``server`` is missing when start() raised in __init__.
        server = getattr(self, "server", None)
        if server is not None:
            server.kill()

    @staticmethod
    def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen:
        """Spawn ``plasma_store`` and verify that it accepts connections.

        Returns:
            The running server process.

        Raises:
            ImportError: if pyarrow is not installed.
        """
        if not PYARROW_AVAILABLE:
            raise ImportError("please run pip install pyarrow to use --use_plasma_view")
        # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm
        _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path])
        try:
            plasma.connect(path, num_retries=200)  # If we can't connect we fail immediately
        except BaseException:
            # Don't leak the child process when the connection check fails;
            # the original error is re-raised unchanged.
            _server.kill()
            raise
        return _server

View File

@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class PrependDataset(BaseWrapperDataset):
    """Wrapper that overwrites the first token of each item.

    The replacement value comes from ``prepend_getter(dataset, idx)``. Note
    that the item obtained from the wrapped dataset is modified in place.
    """

    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        """
        Args:
            dataset: dataset to wrap; items are either a sequence or a tuple
                whose first element is the sequence to modify.
            prepend_getter: callable ``(dataset, idx) -> int``.
            ensure_first_token_is: if given, the current first token must
                equal this value.
        """
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        expected_first = self.ensure_first_token
        assert expected_first is None or src[0] == expected_first

        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        src[0] = prepend_idx

        if is_tuple:
            # Re-assemble the tuple around the modified first element.
            return tuple((src,) + item[1:])
        return src

View File

@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class PrependTokenDataset(BaseWrapperDataset):
    """Wrapper that puts a fixed token in front of every item.

    With ``token=None`` the wrapped dataset is passed through unchanged.
    """

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is None:
            self._sizes = dataset.sizes
        else:
            # Every item grows by the one prepended token.
            self._sizes = np.array(dataset.sizes) + 1

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is None:
            return item
        # item.new(...) builds the token tensor with the item's type.
        return torch.cat([item.new([self.token]), item])

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is None:
            return n
        n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is None:
            return n
        n += 1
        return n

View File

@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class RawLabelDataset(FairseqDataset):
    """Dataset serving raw labels; batches are collated into a tensor."""

    def __init__(self, labels):
        """
        Args:
            labels: indexable, sized collection of labels.
        """
        super().__init__()
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        label = self.labels[index]
        return label

    def collater(self, samples):
        batch = torch.tensor(samples)
        return batch

View File

@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token

    Replacement happens in place on the items returned by the wrapped
    dataset, and the map entries are applied one after another.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map(Dictionary[int,int]): map of token to replace -> replacement token
        offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
            as many as the number of objects returned by the underlying dataset __getitem__ method.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        if is_tuple:
            srcs = item
        else:
            srcs = [item]

        for offset, src in zip(self.offsets, srcs):
            for old_token, new_token in self.replace_map.items():
                # Slice the region that may be modified; slicing gives a
                # view, so masked_fill_ edits ``src`` itself.
                if offset >= 0:
                    window = src[offset:]
                else:
                    window = src[:offset]
                window.masked_fill_(window == old_token, new_token)

        if is_tuple:
            return srcs
        return srcs[0]

View File

@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from fairseq.data import BaseWrapperDataset, plasma_utils
logger = logging.getLogger(__name__)
class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio should be strictly less
    than 1.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        self.weights = None
        if weights is not None:
            assert len(weights) == len(dataset)
            # Normalize to a probability distribution before storing.
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0
        if not self.replace:
            assert size_ratio < 1.0
        self.size_ratio = float(size_ratio)
        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)

        self.batch_by_size = batch_by_size
        self.seed = seed

        self._cur_epoch = None
        self._cur_indices = None

        # Draws the index sample for the starting epoch.
        self.set_epoch(epoch)

    def __getitem__(self, index):
        mapped = self._cur_indices.array[index]
        return self.dataset[mapped]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        if isinstance(self.dataset.sizes, list):
            return [s[self._cur_indices.array] for s in self.dataset.sizes]
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        mapped = self._cur_indices.array[index]
        return self.dataset.num_tokens(mapped)

    def size(self, index):
        mapped = self._cur_indices.array[index]
        return self.dataset.size(mapped)

    def ordered_indices(self):
        if not self.batch_by_size:
            return np.arange(len(self))
        # No need to handle `self.shuffle == True`
        sort_keys = [np.arange(len(self)), self.sizes]
        return np.lexsort(sort_keys)

    def prefetch(self, indices):
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch):
        logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng_seed = [
            42,  # magic number
            self.seed % (2**32),  # global seed
            self._cur_epoch,  # epoch index
        ]
        rng = np.random.RandomState(rng_seed)
        probabilities = None if self.weights is None else self.weights.array
        sampled = rng.choice(
            len(self.dataset),
            self.actual_size,
            replace=self.replace,
            p=probabilities,
        )
        self._cur_indices = plasma_utils.PlasmaArray(sampled)

View File

@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class RollDataset(BaseWrapperDataset):
    """Wrapper applying ``torch.roll`` with fixed shifts to every item."""

    def __init__(self, dataset, shifts):
        """
        Args:
            dataset: dataset to wrap.
            shifts: passed as the ``shifts`` argument of ``torch.roll``.
        """
        super().__init__(dataset)
        self.shifts = shifts

    def __getitem__(self, index):
        return torch.roll(self.dataset[index], self.shifts)

View File

@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import OrderedDict
from typing import Dict, Sequence
import numpy as np
from . import FairseqDataset, LanguagePairDataset
logger = logging.getLogger(__name__)
class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
Shorter datasets are repeated in a round-robin fashion to match the length
of the longest one.
Args:
datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
:class:`~fairseq.data.FairseqDataset` instances.
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
"""
def __init__(self, datasets, eval_key=None):
super().__init__()
if isinstance(datasets, dict):
datasets = OrderedDict(datasets)
assert isinstance(datasets, OrderedDict)
assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
for dataset in datasets.values():
assert isinstance(dataset, FairseqDataset)
self.datasets = datasets
self.eval_key = eval_key
self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
self.longest_dataset = datasets[self.longest_dataset_key]
self._ordered_indices: Dict[str, Sequence[int]] = None
def _map_index(self, key, index):
assert (
self._ordered_indices is not None
), "Must call RoundRobinZipDatasets.ordered_indices() first"
o = self._ordered_indices[key]
return o[index % len(o)]
def __getitem__(self, index):
if self.eval_key is None:
return OrderedDict(
[
(key, dataset[self._map_index(key, index)])
for key, dataset in self.datasets.items()
]
)
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]
def __len__(self):
if self._ordered_indices is not None:
return len(self._ordered_indices[self.longest_dataset_key])
return len(self.longest_dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
if len(samples) == 0:
return None
if self.eval_key is None:
return OrderedDict(
[
(key, dataset.collater([sample[key] for sample in samples]))
for key, dataset in self.datasets.items()
]
)
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key].collater(samples)
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
# TODO make it configurable whether to use max() or sum() here
return max(
dataset.num_tokens(self._map_index(key, index))
for key, dataset in self.datasets.items()
)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return {
key: dataset.size(self._map_index(key, index))
for key, dataset in self.datasets.items()
}
def ordered_indices(self):
"""Ordered indices for batching."""
if self._ordered_indices is None:
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying sub-datasets directly.
self._ordered_indices = OrderedDict(
[
(key, dataset.ordered_indices())
for key, dataset in self.datasets.items()
]
)
return np.arange(len(self))
def filter_indices_by_size(self, indices, max_positions=None):
"""
Filter each sub-dataset independently, then update the round robin to work
on the filtered sub-datasets.
"""
def _deep_until_language_pair(dataset):
if isinstance(dataset, LanguagePairDataset):
return dataset
if hasattr(dataset, "tgt_dataset"):
return _deep_until_language_pair(dataset.tgt_dataset)
if hasattr(dataset, "dataset"):
return _deep_until_language_pair(dataset.dataset)
raise Exception(f"Don't know how to unwrap this dataset: {dataset}")
if not isinstance(max_positions, dict):
max_positions = {k: max_positions for k in self.datasets.keys()}
ignored_some = False
for key, dataset in self.datasets.items():
dataset = _deep_until_language_pair(dataset)
self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
self._ordered_indices[key], max_positions[key]
)
if len(ignored) > 0:
ignored_some = True
logger.warning(
f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
)
# Since we are modifying in place the _ordered_indices,
# it's not possible anymore to return valid ignored indices.
# Hopefully the extra debug information print above should be enough to debug.
# Ideally we would receive ignore_invalid_inputs so that we could have
# a proper error message.
return (np.arange(len(self)), [0] if ignored_some else [])
@property
def supports_prefetch(self):
    """True only if every wrapped dataset reports ``supports_prefetch``.

    A wrapped dataset without that attribute counts as not supporting it.
    """
    for sub_dataset in self.datasets.values():
        if not getattr(sub_dataset, "supports_prefetch", False):
            return False
    return True
def prefetch(self, indices):
    """Prefetch ``indices`` in every wrapped dataset.

    Each index is first translated with ``self._map_index(key, index)`` so
    that every sub-dataset receives indices in its own index space.
    """
    for key, sub_dataset in self.datasets.items():
        mapped = [self._map_index(key, index) for index in indices]
        sub_dataset.prefetch(mapped)

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from fairseq.data import data_utils
from . import BaseWrapperDataset
class TruncateDataset(BaseWrapperDataset):
    """Truncate a sequence by returning the first truncation_length tokens"""

    def __init__(self, dataset, truncation_length):
        """Wrap ``dataset`` so that items are cut to ``truncation_length``.

        Args:
            dataset: the dataset to wrap; its items must expose ``size(0)``
                and slicing (presumably 1-D tensors; confirm with callers).
            truncation_length: maximum item length to keep; must not be None.
        """
        super().__init__(dataset)
        assert truncation_length is not None
        self.truncation_length = truncation_length
        self.dataset = dataset

    def __getitem__(self, index):
        """Return item ``index``, shortened to its leading tokens if too long."""
        item = self.dataset[index]
        limit = self.truncation_length
        if item.size(0) > limit:
            return item[:limit]
        return item

    @property
    def sizes(self):
        """Sizes of the wrapped dataset, capped at ``truncation_length``."""
        underlying_sizes = self.dataset.sizes
        return np.minimum(underlying_sizes, self.truncation_length)

    def __len__(self):
        """Number of items, identical to the wrapped dataset."""
        return len(self.dataset)
class RandomCropDataset(TruncateDataset):
    """Truncate a sequence by returning a random crop of truncation_length tokens"""

    def __init__(self, dataset, truncation_length, seed=1):
        """Wrap ``dataset``; crops are drawn under a seed built from
        ``seed``, the current epoch and the item index."""
        super().__init__(dataset, truncation_length)
        self.seed = seed
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Only the crop position varies between epochs; item sizes do not.
        return True

    def set_epoch(self, epoch, **unused):
        """Record the epoch so that crops differ from one epoch to the next."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index):
        """Return item ``index``, cropped to ``truncation_length`` tokens at a
        random offset when it is longer than that."""
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            surplus = item.size(0) - self.truncation_length
            if surplus <= 0:
                return item
            # NOTE(review): the upper bound of np.random.randint is exclusive,
            # so the offset never equals ``surplus`` and the very last window
            # is never drawn. Kept as is to preserve existing random streams.
            begin = np.random.randint(0, surplus)
            return item[begin : begin + self.truncation_length]
def maybe_shorten_dataset(
    dataset,
    split,
    shorten_data_split_list,
    shorten_method,
    tokens_per_sample,
    seed,
):
    """Optionally wrap ``dataset`` so that its items are shortened.

    Args:
        dataset: dataset to (maybe) wrap.
        split: name of the split being loaded.
        shorten_data_split_list: comma-separated split names to shorten; an
            empty string selects every split.
        shorten_method: ``"truncate"`` wraps in TruncateDataset,
            ``"random_crop"`` wraps in RandomCropDataset; any other value
            leaves the dataset untouched.
        tokens_per_sample: maximum number of tokens per item after shortening.
        seed: seed forwarded to RandomCropDataset.

    Returns:
        The wrapped dataset, or ``dataset`` itself when nothing applies.
    """
    selected_splits = shorten_data_split_list.split(",")
    applies_to_split = split in selected_splits or len(shorten_data_split_list) == 0
    if not applies_to_split:
        return dataset
    if shorten_method == "truncate":
        return TruncateDataset(dataset, tokens_per_sample)
    if shorten_method == "random_crop":
        return RandomCropDataset(dataset, tokens_per_sample, seed)
    return dataset

View File

@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from . import BaseWrapperDataset
class SortDataset(BaseWrapperDataset):
    """Wrapper whose batching order is a lexicographic sort over key arrays."""

    def __init__(self, dataset, sort_order):
        """
        Args:
            dataset: dataset to wrap.
            sort_order: a single key sequence, or a list/tuple of key
                sequences, each with one entry per item of ``dataset``.
        """
        super().__init__(dataset)
        if isinstance(sort_order, (list, tuple)):
            keys = sort_order
        else:
            # A lone key sequence is promoted to a one-element list.
            keys = [sort_order]
        self.sort_order = keys
        assert all(len(so) == len(dataset) for so in keys)

    def ordered_indices(self):
        """Return ``np.lexsort`` over the stored keys (with np.lexsort the
        last key is the primary sort key)."""
        return np.lexsort(self.sort_order)

View File

@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import Dictionary, FairseqDataset, data_utils
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of samples into a mini-batch sorted by source length.

    Args:
        samples (List[dict]): samples with keys ``"id"``, ``"source"`` and
            optionally ``"target"``.
        pad_idx: padding index handed to ``data_utils.collate_tokens``.
        eos_idx: accepted for interface compatibility; not read here.
        vocab: accepted for interface compatibility; not read here.
        left_pad_source (bool): forwarded as ``left_pad`` for the source.
        left_pad_target (bool): forwarded as ``left_pad`` for the target.
        input_feeding (bool): must be truthy (asserted). When targets are
            present, ``prev_output_tokens`` is built from the targets with
            ``move_eos_to_beginning=True``.
        pad_to_length (dict, optional): per-side lengths, looked up under
            ``"source"`` and ``"target"``.

    Returns:
        dict: ``{}`` for an empty ``samples``; otherwise a batch with keys
        ``id``, ``ntokens``, ``net_input`` (``src_tokens``, ``src_lengths``
        and, with targets, ``prev_output_tokens``), ``target``,
        ``target_lengths``, ``nsentences`` and ``sort_order``. ``target``
        and ``target_lengths`` are ``None`` when the samples carry no target.
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    # ``sample_ids`` rather than ``id`` to avoid shadowing the builtin.
    sample_ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )

    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    sample_ids = sample_ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    target_lengths = None
    if samples[0].get("target", None) is not None:
        target_pad_to = pad_to_length["target"] if pad_to_length is not None else None
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=target_pad_to,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)
        # NOTE(review): rows of ``target`` are already padded, so every entry
        # equals the padded width rather than the per-sample length. Kept
        # unchanged for existing consumers.
        target_lengths = torch.LongTensor([len(t) for t in target])

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=target_pad_to,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": sample_ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # Previously computed unconditionally, which raised TypeError when
        # the samples had no target.
        "target_lengths": target_lengths,
        # Number of samples in the batch; the earlier value was the token
        # length of the first sample's source.
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for T5 dataset.

    Every item is partitioned into alternating non-noise and noise spans.
    ``source`` is the item with each noise span collapsed into one sentinel
    id; ``target`` is the item with each non-noise span collapsed into one
    sentinel id. The k-th sentinel (k starting at 1) is ``len(vocab) - k``.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool, optional): shuffle the elements before batching.
          Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record the epoch; it is part of the per-item random seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for item ``index``.

        The noise mask is drawn under a seed derived from
        ``(seed, epoch, index)``. The item must end with ``vocab.eos()``.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

        return {
            "id": index,
            "source": torch.from_numpy(source),
            "target": torch.from_numpy(target),
        }

    def random_spans_noise_mask(self, length):
        """
        This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length
        # Increase length to avoid degeneracy: at least one noise and one
        # non-noise token are needed below. The result is cut back to
        # ``orig_length`` on return. Without this, a length-1 input indexed
        # past the end of ``span_start_indicator``.
        length = max(length, 2)

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of subsegments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        # Interleave as [nonnoise_0, noise_0, nonnoise_1, noise_1, ...].
        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        # Odd-numbered spans are the noise spans.
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        if start_indices.size:
            # np.roll wraps the last element round to position 0; a mask that
            # is set at both ends would otherwise lose its first span start.
            start_indices[..., 0] = mask_indices[..., 0]

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # making sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        # Masked positions that are not a span start become -1.
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): forwarded to ``collate``
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # Stable sort by size, so ties keep the (possibly shuffled) order.
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Prefetch ``indices`` in the wrapped dataset."""
        # This class only holds ``self.dataset``; the ``self.src`` and
        # ``self.tgt`` attributes used here before are never assigned.
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset reports prefetch support."""
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )

Some files were not shown because too many files have changed in this diff Show More