Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)

This commit is contained in:
Tony Ribeiro
2023-08-10 02:58:52 +02:00
parent 28024c5649
commit 60a8e5c9c6
465 changed files with 95671 additions and 0 deletions

View File

@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
import importlib
import os
import numpy as np
class AudioTransform(ABC):
@classmethod
@abstractmethod
def from_config_dict(cls, config: Optional[Dict] = None):
pass
class CompositeAudioTransform(AudioTransform):
def _from_config_dict(
cls,
transform_type,
get_audio_transform,
composite_cls,
config=None,
return_empty=False,
):
_config = {} if config is None else config
_transforms = _config.get(f"{transform_type}_transforms")
if _transforms is None:
if return_empty:
_transforms = []
else:
return None
transforms = [
get_audio_transform(_t).from_config_dict(_config.get(_t))
for _t in _transforms
]
return composite_cls(transforms)
def __init__(self, transforms):
self.transforms = [t for t in transforms if t is not None]
def __call__(self, x):
for t in self.transforms:
x = t(x)
return x
def __repr__(self):
format_string = (
[self.__class__.__name__ + "("]
+ [f" {t.__repr__()}" for t in self.transforms]
+ [")"]
)
return "\n".join(format_string)
def register_audio_transform(name, cls_type, registry, class_names):
def register_audio_transform_cls(cls):
if name in registry:
raise ValueError(f"Cannot register duplicate transform ({name})")
if not issubclass(cls, cls_type):
raise ValueError(
f"Transform ({name}: {cls.__name__}) must extend "
f"{cls_type.__name__}"
)
if cls.__name__ in class_names:
raise ValueError(
f"Cannot register audio transform with duplicate "
f"class name ({cls.__name__})"
)
registry[name] = cls
class_names.add(cls.__name__)
return cls
return register_audio_transform_cls
def import_transforms(transforms_dir, transform_type):
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
name = file[: file.find(".py")] if file.endswith(".py") else file
importlib.import_module(
f"fairseq.data.audio.{transform_type}_transforms." + name
)
# Utility fn for uniform numbers in transforms
def rand_uniform(a, b):
return np.random.uniform() * (b - a) + a

View File

@@ -0,0 +1,389 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
from pathlib import Path
import io
from typing import BinaryIO, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def convert_waveform(
waveform: Union[np.ndarray, torch.Tensor],
sample_rate: int,
normalize_volume: bool = False,
to_mono: bool = False,
to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
"""convert a waveform:
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
try:
import torchaudio.sox_effects as ta_sox
except ImportError:
raise ImportError("Please install torchaudio: pip install torchaudio")
effects = []
if normalize_volume:
effects.append(["gain", "-n"])
if to_sample_rate is not None and to_sample_rate != sample_rate:
effects.append(["rate", f"{to_sample_rate}"])
if to_mono and waveform.shape[0] > 1:
effects.append(["channels", "1"])
if len(effects) > 0:
is_np_input = isinstance(waveform, np.ndarray)
_waveform = torch.from_numpy(waveform) if is_np_input else waveform
converted, converted_sample_rate = ta_sox.apply_effects_tensor(
_waveform, sample_rate, effects
)
if is_np_input:
converted = converted.numpy()
return converted, converted_sample_rate
return waveform, sample_rate
def get_waveform(
path_or_fp: Union[str, BinaryIO],
normalization: bool = True,
mono: bool = True,
frames: int = -1,
start: int = 0,
always_2d: bool = True,
output_sample_rate: Optional[int] = None,
normalize_volume: bool = False,
waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): normalize values to [-1, 1] (Default: True)
mono (bool): convert multi-channel audio to mono-channel one
frames (int): the number of frames to read. (-1 for reading all)
start (int): Where to start reading. A negative value counts from the end.
always_2d (bool): always return 2D array even for mono-channel audios
output_sample_rate (Optional[int]): output sample rate
normalize_volume (bool): normalize volume
Returns:
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
sample_rate (float): sample rate
"""
if isinstance(path_or_fp, str):
ext = Path(path_or_fp).suffix
if ext not in SF_AUDIO_FILE_EXTENSIONS:
raise ValueError(f"Unsupported audio format: {ext}")
try:
import soundfile as sf
except ImportError:
raise ImportError("Please install soundfile: pip install soundfile")
waveform, sample_rate = sf.read(
path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
)
waveform = waveform.T # T x C -> C x T
waveform, sample_rate = convert_waveform(
waveform,
sample_rate,
normalize_volume=normalize_volume,
to_mono=mono,
to_sample_rate=output_sample_rate,
)
if not normalization:
waveform *= 2**15 # denormalized to 16-bit signed integers
if waveform_transforms is not None:
waveform, sample_rate = waveform_transforms(waveform, sample_rate)
if not always_2d:
waveform = waveform.squeeze(axis=0)
return waveform, sample_rate
def get_features_from_npy_or_audio(path, waveform_transforms=None):
ext = Path(path).suffix
if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
raise ValueError(f'Unsupported file format for "{path}"')
return (
np.load(path)
if ext == ".npy"
else get_fbank(path, waveform_transforms=waveform_transforms)
)
def get_features_or_waveform_from_stored_zip(
path,
byte_offset,
byte_size,
need_waveform=False,
use_sample_rate=None,
waveform_transforms=None,
):
assert path.endswith(".zip")
data = read_from_stored_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
features_or_waveform = np.load(f)
elif is_sf_audio_data(data):
features_or_waveform = (
get_waveform(
f,
always_2d=False,
output_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)[0]
if need_waveform
else get_fbank(f, waveform_transforms=waveform_transforms)
)
else:
raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform
def get_features_or_waveform(
path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
):
"""Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length.
Args:
path (str): File path in the format of "<.npy/.wav/.flac path>" or
"<zip path>:<byte offset>:<byte length>".
need_waveform (bool): return waveform instead of features.
use_sample_rate (int): change sample rate for the input wave file
Returns:
features_or_waveform (numpy.ndarray): speech features or waveform.
"""
_path, slice_ptr = parse_path(path)
if len(slice_ptr) == 0:
if need_waveform:
return get_waveform(
_path,
always_2d=False,
output_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)[0]
return get_features_from_npy_or_audio(
_path, waveform_transforms=waveform_transforms
)
elif len(slice_ptr) == 2:
features_or_waveform = get_features_or_waveform_from_stored_zip(
_path,
slice_ptr[0],
slice_ptr[1],
need_waveform=need_waveform,
use_sample_rate=use_sample_rate,
waveform_transforms=waveform_transforms,
)
else:
raise ValueError(f"Invalid path: {path}")
return features_or_waveform
def _get_kaldi_fbank(
waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.fbank import Fbank, FbankOptions
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(
waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform)
features = ta_kaldi.fbank(
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
)
return features.numpy()
except ImportError:
return None
def get_fbank(
path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
) -> np.ndarray:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
waveform, sample_rate = get_waveform(
path_or_fp, normalization=False, waveform_transforms=waveform_transforms
)
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
if features is None:
features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
if features is None:
raise ImportError(
"Please install pyKaldi or torchaudio to enable "
"online filterbank feature extraction"
)
return features
def is_npy_data(data: bytes) -> bool:
return data[0] == 147 and data[1] == 78
def is_sf_audio_data(data: bytes) -> bool:
is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
return is_wav or is_flac or is_ogg
def mmap_read(path: str, offset: int, length: int) -> bytes:
with open(path, "rb") as f:
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
data = mmap_o[offset : offset + length]
return data
def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
return mmap_read(zip_path, offset, length)
def parse_path(path: str) -> Tuple[str, List[int]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
_path, slice_ptr = path, []
else:
_path, *slice_ptr = path.split(":")
if not Path(_path).is_file():
raise FileNotFoundError(f"File not found: {_path}")
assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
slice_ptr = [int(i) for i in slice_ptr]
return _path, slice_ptr
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
padding = n_fft - win_length
assert padding >= 0
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
def get_fourier_basis(n_fft: int) -> torch.Tensor:
basis = np.fft.fft(np.eye(n_fft))
basis = np.vstack(
[np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
)
return torch.from_numpy(basis).float()
def get_mel_filters(
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
try:
import librosa
except ImportError:
raise ImportError("Please install librosa: pip install librosa")
basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
return torch.from_numpy(basis).float()
class TTSSpectrogram(torch.nn.Module):
def __init__(
self,
n_fft: int,
win_length: int,
hop_length: int,
window_fn: callable = torch.hann_window,
return_phase: bool = False,
) -> None:
super(TTSSpectrogram, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.return_phase = return_phase
basis = get_fourier_basis(n_fft).unsqueeze(1)
basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer("basis", basis)
def forward(
self, waveform: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
padding = (self.n_fft // 2, self.n_fft // 2)
x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
x = F.conv1d(x, self.basis, stride=self.hop_length)
real_part = x[:, : self.n_fft // 2 + 1, :]
imag_part = x[:, self.n_fft // 2 + 1 :, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
if self.return_phase:
phase = torch.atan2(imag_part, real_part)
return magnitude, phase
return magnitude
class TTSMelScale(torch.nn.Module):
def __init__(
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
) -> None:
super(TTSMelScale, self).__init__()
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
self.register_buffer("basis", basis)
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.basis, specgram)

View File

@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional
from fairseq.data import Dictionary
logger = logging.getLogger(__name__)
def get_config_from_yaml(yaml_path: Path):
try:
import yaml
except ImportError:
print("Please install PyYAML: pip install PyYAML")
config = {}
if yaml_path.is_file():
try:
with open(yaml_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
else:
raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
return config
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path: Path):
self.config = get_config_from_yaml(yaml_path)
self.root = yaml_path.parent
def _auto_convert_to_abs_path(self, x):
if isinstance(x, str):
if not Path(x).exists() and (self.root / x).exists():
return (self.root / x).as_posix()
elif isinstance(x, dict):
return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
return x
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", "dict.txt")
@property
def speaker_set_filename(self):
"""speaker set file under data root"""
return self.config.get("speaker_set_filename", None)
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get("shuffle", False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
return self._auto_convert_to_abs_path(tokenizer)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get("prepend_tgt_lang_tag", False)
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get("input_feat_per_channel", 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
return self.config.get("input_channels", 1)
@property
def sample_rate(self):
return self.config.get("sample_rate", 16_000)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return self.config.get("sampling_alpha", 1.0)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return self.config.get("use_audio_input", False)
def standardize_audio(self) -> bool:
return self.use_audio_input and self.config.get("standardize_audio", False)
@property
def use_sample_rate(self):
"""Needed by the dataset loader to see if the model requires
raw audio with specific sample rate as inputs."""
return self.config.get("use_sample_rate", 16000)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return self.config.get("audio_root", "")
def get_transforms(self, transform_type, split, is_train):
"""Split-specific feature transforms. Allowing train set
wildcard `_train`, evaluation set wildcard `_eval` and general
wildcard `*` for matching."""
from copy import deepcopy
cfg = deepcopy(self.config)
_cur = cfg.get(f"{transform_type}transforms", {})
cur = _cur.get(split)
cur = _cur.get("_train") if cur is None and is_train else cur
cur = _cur.get("_eval") if cur is None and not is_train else cur
cur = _cur.get("*") if cur is None else cur
return cur
def get_feature_transforms(self, split, is_train):
cfg = deepcopy(self.config)
# TODO: deprecate transforms
cur = self.get_transforms("", split, is_train)
if cur is not None:
logger.warning(
"Auto converting transforms into feature_transforms, "
"but transforms will be deprecated in the future. Please "
"update this in the config."
)
ft_transforms = self.get_transforms("feature_", split, is_train)
if ft_transforms:
cur.extend(ft_transforms)
else:
cur = self.get_transforms("feature_", split, is_train)
cfg["feature_transforms"] = cur
return cfg
def get_waveform_transforms(self, split, is_train):
cfg = deepcopy(self.config)
cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
return cfg
def get_dataset_transforms(self, split, is_train):
cfg = deepcopy(self.config)
cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
return cfg
@property
def global_cmvn_stats_npz(self) -> Optional[str]:
path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
return self._auto_convert_to_abs_path(path)
@property
def vocoder(self) -> Dict[str, str]:
vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
return self._auto_convert_to_abs_path(vocoder)
@property
def hub(self) -> Dict[str, str]:
return self.config.get("hub", {})
class S2SDataConfig(S2TDataConfig):
"""Wrapper class for data config YAML"""
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("vocab_filename", None)
@property
def pre_tokenizer(self) -> Dict:
return None
@property
def bpe_tokenizer(self) -> Dict:
return None
@property
def input_transformed_channels(self):
"""The number of channels in the audio after feature transforms"""
# TODO: move this into individual transforms
# TODO: deprecate transforms
_cur = self.config.get("transforms", {})
ft_transforms = self.config.get("feature_transforms", {})
if _cur and ft_transforms:
_cur.update(ft_transforms)
else:
_cur = self.config.get("feature_transforms", {})
cur = _cur.get("_train", [])
_channels = self.input_channels
if "delta_deltas" in cur:
_channels *= 3
return _channels
@property
def output_sample_rate(self):
"""The audio sample rate of output target speech"""
return self.config.get("output_sample_rate", 22050)
@property
def target_speaker_embed(self):
"""Target speaker embedding file (one line per target audio sample)"""
return self.config.get("target_speaker_embed", None)
@property
def prepend_tgt_lang_tag_as_bos(self) -> bool:
"""Prepend target lang ID token as the target BOS."""
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
class MultitaskConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path: Path):
config = get_config_from_yaml(yaml_path)
self.config = {}
for k, v in config.items():
self.config[k] = SingleTaskConfig(k, v)
def get_all_tasks(self):
return self.config
def get_single_task(self, name):
assert name in self.config, f"multitask '{name}' does not exist!"
return self.config[name]
@property
def first_pass_decoder_task_index(self):
"""Return the task index of the first-pass text decoder.
If there are multiple 'is_first_pass_decoder: True' in the config file,
the last task is used for the first-pass decoder.
If there is no 'is_first_pass_decoder: True' in the config file,
the last task whose task_name includes 'target' and decoder_type is not ctc.
"""
idx = -1
for i, (k, v) in enumerate(self.config.items()):
if v.is_first_pass_decoder:
idx = i
if idx < 0:
for i, (k, v) in enumerate(self.config.items()):
if k.startswith("target") and v.decoder_type == "transformer":
idx = i
return idx
class SingleTaskConfig(object):
def __init__(self, name, config):
self.task_name = name
self.config = config
dict_path = config.get("dict", "")
self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
@property
def data(self):
return self.config.get("data", "")
@property
def decoder_type(self):
return self.config.get("decoder_type", "transformer")
@property
def decoder_args(self):
"""Decoder arch related args"""
args = self.config.get("decoder_args", {})
return Namespace(**args)
@property
def criterion_cfg(self):
"""cfg for the multitask criterion"""
if self.decoder_type == "ctc":
from fairseq.criterions.ctc import CtcCriterionConfig
cfg = CtcCriterionConfig
cfg.zero_infinity = self.config.get("zero_infinity", True)
else:
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterionConfig,
)
cfg = LabelSmoothedCrossEntropyCriterionConfig
cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
return cfg
@property
def input_from(self):
"""Condition on encoder/decoder of the main model"""
return "decoder" if "decoder_layer" in self.config else "encoder"
@property
def input_layer(self):
if self.input_from == "decoder":
return self.config["decoder_layer"] - 1
else:
# default using the output from the last encoder layer (-1)
return self.config.get("encoder_layer", 0) - 1
@property
def loss_weight_schedule(self):
return (
"decay"
if "loss_weight_max" in self.config
and "loss_weight_decay_steps" in self.config
else "fixed"
)
def get_loss_weight(self, num_updates):
if self.loss_weight_schedule == "fixed":
weight = self.config.get("loss_weight", 1.0)
else: # "decay"
assert (
self.config.get("loss_weight_decay_steps", 0) > 0
), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
loss_weight_min = self.config.get("loss_weight_min", 0.0001)
loss_weight_decay_stepsize = (
self.config["loss_weight_max"] - loss_weight_min
) / self.config["loss_weight_decay_steps"]
weight = max(
self.config["loss_weight_max"]
- loss_weight_decay_stepsize * num_updates,
loss_weight_min,
)
return weight
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def eos_token(self):
"""EOS token during generation"""
return self.config.get("eos_token", "<eos>")
@property
def rdrop_alpha(self):
return self.config.get("rdrop_alpha", 0.0)
@property
def is_first_pass_decoder(self):
flag = self.config.get("is_first_pass_decoder", False)
if flag:
if self.decoder_type == "ctc":
raise ValueError(
"First-pass decoder in the multi-decoder model must not be CTC."
)
if "target" not in self.task_name:
raise Warning(
'The name of the first-pass decoder does not include "target".'
)
return flag
@property
def get_lang_tag_mapping(self):
return self.config.get("lang_tag_mapping", {})

View File

@@ -0,0 +1,53 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioDatasetTransform(AudioTransform):
pass
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
def get_audio_dataset_transform(name):
return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
def register_audio_dataset_transform(name):
return register_audio_transform(
name,
AudioDatasetTransform,
AUDIO_DATASET_TRANSFORM_REGISTRY,
AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "dataset")
class CompositeAudioDatasetTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"dataset",
get_audio_dataset_transform,
CompositeAudioDatasetTransform,
config,
return_empty=True,
)
def get_transform(self, cls):
for t in self.transforms:
if isinstance(t, cls):
return t
return None
def has_transform(self, cls):
return self.get_transform(cls) is not None

View File

@@ -0,0 +1,61 @@
from typing import List
import numpy as np
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return ConcatAugment(
_config.get("rate", _DEFAULTS["rate"]),
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
_config.get("attempts", _DEFAULTS["attempts"]),
)
def __init__(
self,
rate=_DEFAULTS["rate"],
max_tokens=_DEFAULTS["max_tokens"],
attempts=_DEFAULTS["attempts"],
):
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"rate={self.rate}",
f"max_tokens={self.max_tokens}",
f"attempts={self.attempts}",
]
)
+ ")"
)
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
# skip conditions: application rate, max_tokens limit exceeded
if np.random.random() > self.rate:
return [index]
if self.max_tokens and n_frames[index] > self.max_tokens:
return [index]
# pick second sample to concatenate
for _ in range(self.attempts):
index2 = np.random.randint(0, n_samples)
if index2 != index and (
not self.max_tokens
or n_frames[index] + n_frames[index2] < self.max_tokens
):
return [index, index2]
return [index]

View File

@@ -0,0 +1,105 @@
import numpy as np
import torch
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
from fairseq.data.audio.waveform_transforms.noiseaugment import (
NoiseAugmentTransform,
)
_DEFAULTS = {
"rate": 0.25,
"mixing_noise_rate": 0.1,
"noise_path": "",
"noise_snr_min": -5,
"noise_snr_max": 5,
"utterance_snr_min": -5,
"utterance_snr_max": 5,
}
@register_audio_dataset_transform("noisyoverlapaugment")
class NoisyOverlapAugment(AudioDatasetTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return NoisyOverlapAugment(
_config.get("rate", _DEFAULTS["rate"]),
_config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
_config.get("noise_path", _DEFAULTS["noise_path"]),
_config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
_config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
_config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
_config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
)
def __init__(
self,
rate=_DEFAULTS["rate"],
mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
noise_path=_DEFAULTS["noise_path"],
noise_snr_min=_DEFAULTS["noise_snr_min"],
noise_snr_max=_DEFAULTS["noise_snr_max"],
utterance_snr_min=_DEFAULTS["utterance_snr_min"],
utterance_snr_max=_DEFAULTS["utterance_snr_max"],
):
self.rate = rate
self.mixing_noise_rate = mixing_noise_rate
self.noise_shaper = NoiseAugmentTransform(noise_path)
self.noise_snr_min = noise_snr_min
self.noise_snr_max = noise_snr_max
self.utterance_snr_min = utterance_snr_min
self.utterance_snr_max = utterance_snr_max
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"rate={self.rate}",
f"mixing_noise_rate={self.mixing_noise_rate}",
f"noise_snr_min={self.noise_snr_min}",
f"noise_snr_max={self.noise_snr_max}",
f"utterance_snr_min={self.utterance_snr_min}",
f"utterance_snr_max={self.utterance_snr_max}",
]
)
+ ")"
)
def __call__(self, sources):
for i, source in enumerate(sources):
if np.random.random() > self.rate:
continue
pri = source.numpy()
if np.random.random() > self.mixing_noise_rate:
sec = sources[np.random.randint(0, len(sources))].numpy()
snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
else:
sec = self.noise_shaper.pick_sample(source.shape)
snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
L1 = pri.shape[-1]
L2 = sec.shape[-1]
l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
s_source = np.random.randint(0, L1 - l)
s_sec = np.random.randint(0, L2 - l)
get_power = lambda x: np.mean(x**2)
if get_power(sec) == 0:
continue
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
pri[s_source : s_source + l] = np.add(
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
)
sources[i] = torch.from_numpy(pri).float()
return sources

View File

@@ -0,0 +1,43 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioFeatureTransform(AudioTransform):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def get_audio_feature_transform(name):
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
def register_audio_feature_transform(name):
return register_audio_transform(
name,
AudioFeatureTransform,
AUDIO_FEATURE_TRANSFORM_REGISTRY,
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "feature")
class CompositeAudioFeatureTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"feature",
get_audio_feature_transform,
CompositeAudioFeatureTransform,
config,
)

View File

@@ -0,0 +1,37 @@
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
"""Expand delta-deltas features from spectrum."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return DeltaDeltas(_config.get("win_length", 5))
def __init__(self, win_length=5):
self.win_length = win_length
def __repr__(self):
return self.__class__.__name__
def __call__(self, spectrogram):
from torchaudio.functional import compute_deltas
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
# spectrogram is T x F, while compute_deltas takes (…, F, T)
spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
delta = compute_deltas(spectrogram)
delta_delta = compute_deltas(delta)
out_feat = np.concatenate(
[spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
)
out_feat = np.transpose(out_feat)
return out_feat

View File

@@ -0,0 +1,29 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return GlobalCMVN(_config.get("stats_npz_path"))
def __init__(self, stats_npz_path):
self.stats_npz_path = stats_npz_path
stats = np.load(stats_npz_path)
self.mean, self.std = stats["mean"], stats["std"]
def __repr__(self):
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
def __call__(self, x):
x = np.subtract(x, self.mean)
x = np.divide(x, self.std)
return x

View File

@@ -0,0 +1,131 @@
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
_config.get("time_warp_W", 0),
_config.get("freq_mask_N", 0),
_config.get("freq_mask_F", 0),
_config.get("time_mask_N", 0),
_config.get("time_mask_T", 0),
_config.get("time_mask_p", 0.0),
_config.get("mask_value", None),
)
def __init__(
self,
time_warp_w: int = 0,
freq_mask_n: int = 0,
freq_mask_f: int = 0,
time_mask_n: int = 0,
time_mask_t: int = 0,
time_mask_p: float = 0.0,
mask_value: Optional[float] = 0.0,
):
# Sanity checks
assert mask_value is None or isinstance(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
assert freq_mask_f > 0, (
f"freq_mask_F ({freq_mask_f}) "
f"must be larger than 0 when doing freq masking."
)
if time_mask_n > 0:
assert time_mask_t > 0, (
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
f"doing time masking."
)
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
self.freq_mask_f = freq_mask_f
self.time_mask_n = time_mask_n
self.time_mask_t = time_mask_t
self.time_mask_p = time_mask_p
self.mask_value = mask_value
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"time_warp_w={self.time_warp_w}",
f"freq_mask_n={self.freq_mask_n}",
f"freq_mask_f={self.freq_mask_f}",
f"time_mask_n={self.time_mask_n}",
f"time_mask_t={self.time_mask_t}",
f"time_mask_p={self.time_mask_p}",
]
)
+ ")"
)
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
distorted = spectrogram.copy() # make a copy of input spectrogram.
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
mask_value = self.mask_value
if mask_value is None: # if no value was specified, use local mean.
mask_value = spectrogram.mean()
if num_frames == 0:
return spectrogram
if num_freqs < self.freq_mask_f:
return spectrogram
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
dsize=(num_freqs, num_frames - w0 - w),
interpolation=cv2.INTER_LINEAR,
)
distorted = np.concatenate((upper, lower), axis=0)
for _i in range(self.freq_mask_n):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
distorted[:, f0 : f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
)
if max_time_mask_t < 1:
return distorted
for _i in range(self.time_mask_n):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
distorted[t0 : t0 + t, :] = mask_value
return distorted

View File

@@ -0,0 +1,41 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform,
register_audio_feature_transform,
)
@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
_config.get("norm_means", True),
_config.get("norm_vars", True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
return (
self.__class__.__name__
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
)
def __call__(self, x):
mean = x.mean(axis=0)
square_sums = (x**2).sum(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean**2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x

View File

@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs
import csv
import logging
import os.path as op
from typing import List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset,
TextToSpeechDatasetCreator,
)
logger = logging.getLogger(__name__)
class FrmTextToSpeechDataset(TextToSpeechDataset):
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
do_chunk=False,
chunk_bound=-1,
chunk_init=50,
chunk_incr=5,
add_eos=True,
dedup=True,
ref_fpu=-1,
):
# It assumes texts are encoded at a fixed frame-rate
super().__init__(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.do_chunk = do_chunk
self.chunk_bound = chunk_bound
self.chunk_init = chunk_init
self.chunk_incr = chunk_incr
self.add_eos = add_eos
self.dedup = dedup
self.ref_fpu = ref_fpu
self.chunk_size = -1
if do_chunk:
assert self.chunk_incr >= 0
assert self.pre_tokenizer is None
def __getitem__(self, index):
index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
if target[-1].item() == self.tgt_dict.eos_index:
target = target[:-1]
fpu = source.size(0) / target.size(0) # frame-per-unit
fps = self.n_frames_per_step
assert (
self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
), f"{fpu*fps} != {self.ref_fpu}"
# only chunk training split
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
size = len(text)
chunk_size = min(self.chunk_size, size)
chunk_start = np.random.randint(size - chunk_size + 1)
text = text[chunk_start : chunk_start + chunk_size]
target = torch.cat((lang, text), 0)
f_size = int(np.floor(chunk_size * fpu))
f_start = int(np.floor(chunk_start * fpu))
assert f_size > 0
source = source[f_start : f_start + f_size, :]
if self.dedup:
target = torch.unique_consecutive(target)
if self.add_eos:
eos_idx = self.tgt_dict.eos_index
target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
return index, source, target, speaker_id
def set_epoch(self, epoch):
if self.is_train_split and self.do_chunk:
old = self.chunk_size
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
if self.chunk_bound > 0:
self.chunk_size = min(self.chunk_size, self.chunk_bound)
logger.info(
(
f"{self.split}: setting chunk size "
f"from {old} to {self.chunk_size}"
)
)
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
# inherit for key names
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2TDataConfig,
split: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
n_frames_per_step: int,
speaker_to_id,
do_chunk: bool = False,
chunk_bound: int = -1,
chunk_init: int = 50,
chunk_incr: int = 5,
add_eos: bool = True,
dedup: bool = True,
ref_fpu: float = -1,
) -> FrmTextToSpeechDataset:
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
s = [dict(e) for e in reader]
assert len(s) > 0
ids = [ss[cls.KEY_ID] for ss in s]
audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
return FrmTextToSpeechDataset(
split=split,
is_train_split=is_train_split,
data_cfg=data_cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
do_chunk=do_chunk,
chunk_bound=chunk_bound,
chunk_init=chunk_init,
chunk_incr=chunk_incr,
add_eos=add_eos,
dedup=dedup,
ref_fpu=ref_fpu,
)

View File

@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
)
import io
logger = logging.getLogger(__name__)
def load_audio(manifest_path, max_keep, min_keep):
n_long, n_short = 0, 0
names, inds, sizes = [], [], []
with open(manifest_path) as f:
root = f.readline().strip()
for ind, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_keep is not None and sz < min_keep:
n_short += 1
elif max_keep is not None and sz > max_keep:
n_long += 1
else:
names.append(items[0])
inds.append(ind)
sizes.append(sz)
tot = ind + 1
logger.info(
(
f"max_keep={max_keep}, min_keep={min_keep}, "
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
)
)
return root, names, inds, tot, sizes
def load_label(label_path, inds, tot):
with open(label_path) as f:
labels = [line.rstrip() for line in f]
assert (
len(labels) == tot
), f"number of labels does not match ({len(labels)} != {tot})"
labels = [labels[i] for i in inds]
return labels
def load_label_offset(label_path, inds, tot):
with open(label_path) as f:
code_lengths = [len(line.encode("utf-8")) for line in f]
assert (
len(code_lengths) == tot
), f"number of labels does not match ({len(code_lengths)} != {tot})"
offsets = list(itertools.accumulate([0] + code_lengths))
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
return offsets
def verify_label_lengths(
audio_sizes,
audio_rate,
label_path,
label_rate,
inds,
tot,
tol=0.1, # tolerance in seconds
):
if label_rate < 0:
logger.info(f"{label_path} is sequence label. skipped")
return
with open(label_path) as f:
lengths = [len(line.rstrip().split()) for line in f]
assert len(lengths) == tot
lengths = [lengths[i] for i in inds]
num_invalid = 0
for i, ind in enumerate(inds):
dur_from_audio = audio_sizes[i] / audio_rate
dur_from_label = lengths[i] / label_rate
if abs(dur_from_audio - dur_from_label) > tol:
logger.warning(
(
f"audio and label duration differ too much "
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
f"in line {ind+1} of {label_path}. Check if `label_rate` "
f"is correctly set (currently {label_rate}). "
f"num. of samples = {audio_sizes[i]}; "
f"label length = {lengths[i]}"
)
)
num_invalid += 1
if num_invalid > 0:
logger.warning(
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
)
class HubertDataset(FairseqDataset):
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[str],
eos_list: List[str],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
):
self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
manifest_path, max_keep_sample_size, min_keep_sample_size
)
self.sample_rate = sample_rate
self.shuffle = shuffle
self.random_crop = random_crop
self.num_labels = len(label_paths)
self.pad_list = pad_list
self.eos_list = eos_list
self.label_processors = label_processors
self.single_target = single_target
self.label_rates = (
[label_rates for _ in range(len(label_paths))]
if isinstance(label_rates, float)
else label_rates
)
self.store_labels = store_labels
if store_labels:
self.label_list = [load_label(p, inds, tot) for p in label_paths]
else:
self.label_paths = label_paths
self.label_offsets_list = [
load_label_offset(p, inds, tot) for p in label_paths
]
assert label_processors is None or len(label_processors) == self.num_labels
for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths(
self.sizes, sample_rate, label_path, label_rate, inds, tot
)
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.pad_audio = pad_audio
self.normalize = normalize
logger.info(
f"pad_audio={pad_audio}, random_crop={random_crop}, "
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
)
def get_audio(self, index):
import soundfile as sf
wav_path = os.path.join(self.audio_root, self.audio_names[index])
_path, slice_ptr = parse_path(wav_path)
if len(slice_ptr) == 0:
wav, cur_sample_rate = sf.read(_path)
else:
assert _path.endswith(".zip")
data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
f = io.BytesIO(data)
wav, cur_sample_rate = sf.read(f)
wav = torch.from_numpy(wav).float()
wav = self.postprocess(wav, cur_sample_rate)
return wav
def get_label(self, index, label_idx):
if self.store_labels:
label = self.label_list[label_idx][index]
else:
with open(self.label_paths[label_idx]) as f:
offset_s, offset_e = self.label_offsets_list[label_idx][index]
f.seek(offset_s)
label = f.read(offset_e - offset_s)
if self.label_processors is not None:
label = self.label_processors[label_idx](label)
return label
def get_labels(self, index):
return [self.get_label(index, i) for i in range(self.num_labels)]
def __getitem__(self, index):
wav = self.get_audio(index)
labels = self.get_labels(index)
return {"id": index, "source": wav, "label_list": labels}
def __len__(self):
return len(self.sizes)
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater(self, samples):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size
)
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch
def collater_audio(self, audios, audio_size):
collated_audios = audios[0].new_zeros(len(audios), audio_size)
padding_mask = (
torch.BoolTensor(collated_audios.shape).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, audio in enumerate(audios):
diff = len(audio) - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios, padding_mask, audio_starts
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
targets_list, lengths_list, ntokens_list = [], [], []
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1.0:
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad
)
targets_list.append(targets)
lengths_list.append(lengths)
ntokens_list.append(ntokens)
return targets_list, lengths_list, ntokens_list
def num_tokens(self, index):
return self.size(index)
def size(self, index):
if self.pad_audio:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)[::-1]
def postprocess(self, wav, cur_sample_rate):
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav

View File

@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import math
from typing import List, Optional, NamedTuple
import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
ConcatDataset,
LanguagePairDataset,
FileAudioDataset,
data_utils,
)
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class ModalityDatasetItem(NamedTuple):
datasetname: str
dataset: any
max_positions: List[int]
max_tokens: Optional[int] = None
max_sentences: Optional[int] = None
def resampling_dataset_present(ds):
if isinstance(ds, ResamplingDataset):
return True
if isinstance(ds, ConcatDataset):
return any(resampling_dataset_present(d) for d in ds.datasets)
if hasattr(ds, "dataset"):
return resampling_dataset_present(ds.dataset)
return False
# MultiModalityDataset: it concate multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds mode to indicate what type of the data samples come from.
# It will be used with GroupedEpochBatchIterator together to generate mini-batch with samples
# from the same type of dataset
# If only one dataset is used, it will perform like the original dataset with mode added
class MultiModalityDataset(ConcatDataset):
def __init__(self, datasets: List[ModalityDatasetItem]):
id_to_mode = []
dsets = []
max_tokens = []
max_sentences = []
max_positions = []
for dset in datasets:
id_to_mode.append(dset.datasetname)
dsets.append(dset.dataset)
max_tokens.append(dset.max_tokens)
max_positions.append(dset.max_positions)
max_sentences.append(dset.max_sentences)
weights = [1.0 for s in dsets]
super().__init__(dsets, weights)
self.max_tokens = max_tokens
self.max_positions = max_positions
self.max_sentences = max_sentences
self.id_to_mode = id_to_mode
self.raw_sub_batch_samplers = []
self._cur_epoch = 0
def set_epoch(self, epoch):
super().set_epoch(epoch)
self._cur_epoch = epoch
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
sample = self.datasets[dataset_idx][sample_idx]
return (dataset_idx, sample)
def collater(self, samples):
if len(samples) == 0:
return {}
dataset_idx = samples[0][0]
# make sure all samples in samples are from same dataset
assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
# add mode
samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
return samples
def size(self, index: int):
if len(self.datasets) == 1:
return self.datasets[0].size(index)
return super().size(index)
@property
def sizes(self):
if len(self.datasets) == 1:
return self.datasets[0].sizes
return super().sizes
def ordered_indices(self):
"""
Returns indices sorted by length. So less padding is needed.
"""
if len(self.datasets) == 1:
return self.datasets[0].ordered_indices()
indices_group = []
for d_idx, ds in enumerate(self.datasets):
sample_num = self.cumulative_sizes[d_idx]
if d_idx > 0:
sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
assert sample_num == len(ds)
indices_group.append(ds.ordered_indices())
return indices_group
def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
with data_utils.numpy_seed(seed):
indices = self.ordered_indices()
for i, ds in enumerate(self.datasets):
# If we have ResamplingDataset, the same id can correpond to a different
# sample in the next epoch, so we need to rebuild this at every epoch
if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
ds
):
logger.info(f"dataset {i} is valid and it is not re-sampled")
continue
indices[i] = ds.filter_indices_by_size(
indices[i],
self.max_positions[i],
)[0]
sub_batch_sampler = ds.batch_by_size(
indices[i],
max_tokens=self.max_tokens[i],
max_sentences=self.max_sentences[i],
required_batch_size_multiple=required_batch_size_multiple,
)
if i < len(self.raw_sub_batch_samplers):
self.raw_sub_batch_samplers[i] = sub_batch_sampler
else:
self.raw_sub_batch_samplers.append(sub_batch_sampler)
def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
batch_samplers = []
for i, _ in enumerate(self.datasets):
if i > 0:
sub_batch_sampler = [
[y + self.cumulative_sizes[i - 1] for y in x]
for x in self.raw_sub_batch_samplers[i]
]
else:
sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
smp_r = mult_ratios[i]
if smp_r != 1:
is_increase = "increased" if smp_r > 1 else "decreased"
logger.info(
"number of batch for the dataset {} is {} from {} to {}".format(
self.id_to_mode[i],
is_increase,
len(sub_batch_sampler),
int(len(sub_batch_sampler) * smp_r),
)
)
mul_samplers = []
for _ in range(math.floor(smp_r)):
mul_samplers = mul_samplers + sub_batch_sampler
if math.floor(smp_r) != smp_r:
with data_utils.numpy_seed(seed + self._cur_epoch):
np.random.shuffle(sub_batch_sampler)
smp_num = int(
(smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
)
mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
sub_batch_sampler = mul_samplers
else:
logger.info(
"dataset {} batch number is {} ".format(
self.id_to_mode[i], len(sub_batch_sampler)
)
)
batch_samplers.append(sub_batch_sampler)
return batch_samplers
class LangPairMaskDataset(FairseqDataset):
def __init__(
self,
dataset: LanguagePairDataset,
src_eos: int,
src_bos: Optional[int] = None,
noise_id: Optional[int] = -1,
mask_ratio: Optional[float] = 0,
mask_type: Optional[str] = "random",
):
self.dataset = dataset
self.src_eos = src_eos
self.src_bos = src_bos
self.noise_id = noise_id
self.mask_ratio = mask_ratio
self.mask_type = mask_type
assert mask_type in ("random", "tail")
@property
def src_sizes(self):
return self.dataset.src_sizes
@property
def tgt_sizes(self):
return self.dataset.tgt_sizes
@property
def sizes(self):
# dataset.sizes can be a dynamically computed sizes:
return self.dataset.sizes
def get_batch_shapes(self):
if hasattr(self.dataset, "get_batch_shapes"):
return self.dataset.get_batch_shapes()
return self.dataset.buckets
def num_tokens_vec(self, indices):
return self.dataset.num_tokens_vec(indices)
def __len__(self):
return len(self.dataset)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
def mask_src_tokens(self, sample):
src_item = sample["source"]
mask = None
if self.mask_type == "random":
mask = torch.rand(len(src_item)).le(self.mask_ratio)
else:
mask = torch.ones(len(src_item))
mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
mask = mask.eq(1)
if src_item[0] == self.src_bos:
mask[0] = False
if src_item[-1] == self.src_eos:
mask[-1] = False
mask_src_item = src_item.masked_fill(mask, self.noise_id)
smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
return smp
def __getitem__(self, index):
sample = self.dataset[index]
if self.mask_ratio > 0:
sample = self.mask_src_tokens(sample)
return sample
def collater(self, samples, pad_to_length=None):
return self.dataset.collater(samples, pad_to_length)
class FileAudioDatasetWrapper(FileAudioDataset):
def collater(self, samples):
samples = super().collater(samples)
if len(samples) == 0:
return {}
samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
samples["net_input"]["prev_output_tokens"] = None
del samples["net_input"]["source"]
samples["net_input"]["src_lengths"] = None
samples["net_input"]["alignment"] = None
return samples

View File

@@ -0,0 +1,393 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
import io
import numpy as np
import torch
import torch.nn.functional as F
from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
logger = logging.getLogger(__name__)
class RawAudioDataset(FairseqDataset):
def __init__(
self,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
compute_mask_indices=False,
**mask_compute_kwargs,
):
super().__init__()
self.sample_rate = sample_rate
self.sizes = []
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.min_sample_size = min_sample_size
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
self.compute_mask_indices = compute_mask_indices
if self.compute_mask_indices:
self.mask_compute_kwargs = mask_compute_kwargs
self._features_size_map = {}
self._C = mask_compute_kwargs["encoder_embed_dim"]
self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
def __getitem__(self, index):
raise NotImplementedError()
def __len__(self):
return len(self.sizes)
def postprocess(self, feats, curr_sample_rate):
if feats.dim() == 2:
feats = feats.mean(-1)
if curr_sample_rate != self.sample_rate:
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
assert feats.dim() == 1, feats.dim()
if self.normalize:
with torch.no_grad():
feats = F.layer_norm(feats, feats.shape)
return feats
def crop_to_max_size(self, wav, target_size):
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end]
def _compute_mask_indices(self, dims, padding_mask):
B, T, C = dims
mask_indices, mask_channel_indices = None, None
if self.mask_compute_kwargs["mask_prob"] > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_compute_kwargs["mask_prob"],
self.mask_compute_kwargs["mask_length"],
self.mask_compute_kwargs["mask_selection"],
self.mask_compute_kwargs["mask_other"],
min_masks=2,
no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
min_space=self.mask_compute_kwargs["mask_min_space"],
)
mask_indices = torch.from_numpy(mask_indices)
if self.mask_compute_kwargs["mask_channel_prob"] > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_compute_kwargs["mask_channel_prob"],
self.mask_compute_kwargs["mask_channel_length"],
self.mask_compute_kwargs["mask_channel_selection"],
self.mask_compute_kwargs["mask_channel_other"],
no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
min_space=self.mask_compute_kwargs["mask_channel_min_space"],
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
)
return mask_indices, mask_channel_indices
@staticmethod
def _bucket_tensor(tensor, num_pad, value):
return F.pad(tensor, (0, num_pad), value=value)
def collater(self, samples):
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
sources = [s["source"] for s in samples]
sizes = [len(s) for s in sources]
if self.pad:
target_size = min(max(sizes), self.max_sample_size)
else:
target_size = min(min(sizes), self.max_sample_size)
collated_sources = sources[0].new_zeros(len(sources), target_size)
padding_mask = (
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
)
for i, (source, size) in enumerate(zip(sources, sizes)):
diff = size - target_size
if diff == 0:
collated_sources[i] = source
elif diff < 0:
assert self.pad
collated_sources[i] = torch.cat(
[source, source.new_full((-diff,), 0.0)]
)
padding_mask[i, diff:] = True
else:
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask
if hasattr(self, "num_buckets") and self.num_buckets > 0:
assert self.pad, "Cannot bucket without padding first."
bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
num_pad = bucket - collated_sources.size(-1)
if num_pad:
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
if self.compute_mask_indices:
B = input["source"].size(0)
T = self._get_mask_indices_dims(input["source"].size(-1))
padding_mask_reshaped = input["padding_mask"].clone()
extra = padding_mask_reshaped.size(1) % T
if extra > 0:
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
padding_mask_reshaped = padding_mask_reshaped.view(
padding_mask_reshaped.size(0), T, -1
)
padding_mask_reshaped = padding_mask_reshaped.all(-1)
input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
mask_indices, mask_channel_indices = self._compute_mask_indices(
(B, T, self._C),
padding_mask_reshaped,
)
input["mask_indices"] = mask_indices
input["mask_channel_indices"] = mask_channel_indices
out["sample_size"] = mask_indices.sum().item()
out["net_input"] = input
return out
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
if size not in self._features_size_map:
L_in = size
for (_, kernel_size, stride) in self._conv_feature_layers:
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
L_out = 1 + L_out // stride
L_in = L_out
self._features_size_map[size] = L_out
return self._features_size_map[size]
def num_tokens(self, index):
return self.size(index)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if self.pad:
return self.sizes[index]
return min(self.sizes[index], self.max_sample_size)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
order.append(
np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
)
return np.lexsort(order)[::-1]
else:
return np.arange(len(self))
def set_bucket_info(self, num_buckets):
self.num_buckets = num_buckets
if self.num_buckets > 0:
self._collated_sizes = np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
self.buckets = get_buckets(
self._collated_sizes,
self.num_buckets,
)
self._bucketed_sizes = get_bucketed_sizes(
self._collated_sizes, self.buckets
)
logger.info(
f"{len(self.buckets)} bucket(s) for the audio dataset: "
f"{self.buckets}"
)
class FileAudioDataset(RawAudioDataset):
def __init__(
self,
manifest_path,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
num_buckets=0,
compute_mask_indices=False,
text_compression_level=TextCompressionLevel.none,
**mask_compute_kwargs,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
**mask_compute_kwargs,
)
self.text_compressor = TextCompressor(level=text_compression_level)
skipped = 0
self.fnames = []
sizes = []
self.skipped_indices = set()
with open(manifest_path, "r") as f:
self.root_dir = f.readline().strip()
for i, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_sample_size is not None and sz < min_sample_size:
skipped += 1
self.skipped_indices.add(i)
continue
self.fnames.append(self.text_compressor.compress(items[0]))
sizes.append(sz)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
self.sizes = np.array(sizes, dtype=np.int64)
try:
import pyarrow
self.fnames = pyarrow.array(self.fnames)
except:
logger.debug(
"Could not create a pyarrow array. Please install pyarrow for better performance"
)
pass
self.set_bucket_info(num_buckets)
def __getitem__(self, index):
import soundfile as sf
fn = self.fnames[index]
fn = fn if isinstance(self.fnames, list) else fn.as_py()
fn = self.text_compressor.decompress(fn)
path_or_fp = os.path.join(self.root_dir, fn)
_path, slice_ptr = parse_path(path_or_fp)
if len(slice_ptr) == 2:
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
assert is_sf_audio_data(byte_data)
path_or_fp = io.BytesIO(byte_data)
wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}
class BinarizedAudioDataset(RawAudioDataset):
def __init__(
self,
data_dir,
split,
sample_rate,
max_sample_size=None,
min_sample_size=0,
shuffle=True,
pad=False,
normalize=False,
num_buckets=0,
compute_mask_indices=False,
**mask_compute_kwargs,
):
super().__init__(
sample_rate=sample_rate,
max_sample_size=max_sample_size,
min_sample_size=min_sample_size,
shuffle=shuffle,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
**mask_compute_kwargs,
)
from fairseq.data import data_utils, Dictionary
self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
root_path = os.path.join(data_dir, f"{split}.root")
if os.path.exists(root_path):
with open(root_path, "r") as f:
self.root_dir = next(f).strip()
else:
self.root_dir = None
fnames_path = os.path.join(data_dir, split)
self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
lengths_path = os.path.join(data_dir, f"{split}.lengths")
with open(lengths_path, "r") as f:
for line in f:
sz = int(line.rstrip())
assert (
sz >= min_sample_size
), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
self.sizes.append(sz)
self.sizes = np.array(self.sizes, dtype=np.int64)
self.set_bucket_info(num_buckets)
logger.info(f"loaded {len(self.fnames)} samples")
def __getitem__(self, index):
import soundfile as sf
fname = self.fnames_dict.string(self.fnames[index], separator="")
if self.root_dir:
fname = os.path.join(self.root_dir, fname)
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
return {"id": index, "source": feats}

View File

@@ -0,0 +1,379 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
TextTargetMultitaskData,
_collate_frames,
)
logger = logging.getLogger(__name__)
@dataclass
class SpeechToSpeechDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
target_speaker: Optional[torch.Tensor] = None
tgt_lang_tag: Optional[int] = None
class SpeechToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2SDataConfig,
src_audio_paths: List[str],
src_n_frames: List[int],
tgt_audio_paths: List[str],
tgt_n_frames: List[int],
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
):
tgt_texts = tgt_audio_paths if target_is_code else None
super().__init__(
split=split,
is_train_split=is_train_split,
cfg=data_cfg,
audio_paths=src_audio_paths,
n_frames=src_n_frames,
ids=ids,
tgt_dict=tgt_dict,
tgt_texts=tgt_texts,
src_langs=src_langs,
tgt_langs=tgt_langs,
n_frames_per_step=n_frames_per_step,
)
self.tgt_audio_paths = tgt_audio_paths
self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
assert not target_is_code or tgt_dict is not None
self.target_is_code = target_is_code
assert len(tgt_audio_paths) == self.n_samples
assert len(tgt_n_frames) == self.n_samples
self.tgt_speakers = None
if self.cfg.target_speaker_embed:
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
self.cfg.target_speaker_embed, split
)
spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
assert len(self.tgt_speakers) == self.n_samples
logger.info(self.__repr__())
def pack_units(self, input: torch.Tensor) -> torch.Tensor:
if self.n_frames_per_step <= 1:
return input
offset = 4
vocab_size = (
len(self.tgt_dict) - offset
) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
assert input.dim() == 1
stacked_input = (
input[:-1].view(-1, self.n_frames_per_step) - offset
) # remove <eos>
scale = [
pow(vocab_size, self.n_frames_per_step - 1 - i)
for i in range(self.n_frames_per_step)
]
scale = torch.LongTensor(scale).squeeze(0)
res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
res[:-1] = (stacked_input * scale).sum(dim=1) + offset
return res
def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
source = self._get_source_audio(index)
tgt_lang_tag = None
if self.cfg.prepend_tgt_lang_tag_as_bos:
# prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
if not self.target_is_code:
target = get_features_or_waveform(self.tgt_audio_paths[index])
target = torch.from_numpy(target).float()
target = self.pack_frames(target)
else:
target = self.tgt_dict.encode_line(
self.tgt_audio_paths[index],
add_if_not_exist=False,
append_eos=True,
).long()
if self.n_frames_per_step > 1:
n_tgt_frame = target.size(0) - 1 # exclude <eos>
keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
target = torch.cat(
(
target[:keep_n_tgt_frame],
target.new_full((1,), self.tgt_dict.eos()),
),
dim=0,
)
if self.tgt_speakers:
tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
tgt_spk = torch.from_numpy(tgt_spk).float()
else:
tgt_spk = torch.FloatTensor([])
return SpeechToSpeechDatasetItem(
index=index,
source=source,
target=target,
target_speaker=tgt_spk,
tgt_lang_tag=tgt_lang_tag,
)
def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
if self.target_is_code:
target = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
# convert stacked units to a single id
pack_targets = [self.pack_units(x.target) for x in samples]
prev_output_tokens = fairseq_data_utils.collate_tokens(
pack_targets,
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
)
target_lengths = torch.tensor(
[x.size(0) for x in pack_targets], dtype=torch.long
)
else:
target = _collate_frames([x.target for x in samples], is_audio_input=False)
bsz, _, d = target.size()
prev_output_tokens = torch.cat(
(target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
)
target_lengths = torch.tensor(
[x.target.size(0) for x in samples], dtype=torch.long
)
return target, prev_output_tokens, target_lengths
def collater(
self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
# sort samples by descending number of frames
n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, prev_output_tokens, target_lengths = self._collate_target(samples)
target = target.index_select(0, order)
target_lengths = target_lengths.index_select(0, order)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(x.target.size(0) for x in samples)
tgt_speakers = None
if self.cfg.target_speaker_embed:
tgt_speakers = _collate_frames(
[x.target_speaker for x in samples], is_audio_input=True
).index_select(0, order)
net_input = {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
"tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
}
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
for i in range(len(samples)):
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
out = {
"id": indices,
"net_input": net_input,
"speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
if return_order:
out["order"] = order
return out
class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multitask_data = {}
def add_multitask_dataset(self, task_name, task_data):
self.multitask_data[task_name] = task_data
def __getitem__(
self, index: int
) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
s2s_data = super().__getitem__(index)
multitask_target = {}
sample_id = self.ids[index]
tgt_lang = self.tgt_langs[index]
for task_name, task_dataset in self.multitask_data.items():
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
return s2s_data, multitask_target
def collater(
self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
) -> Dict:
if len(samples) == 0:
return {}
out = super().collater([s for s, _ in samples], return_order=True)
order = out["order"]
del out["order"]
for task_name, task_dataset in self.multitask_data.items():
if "multitask" not in out:
out["multitask"] = {}
d = [s[task_name] for _, s in samples]
task_target = task_dataset.collater(d)
out["multitask"][task_name] = {
"target": task_target["target"].index_select(0, order),
"target_lengths": task_target["target_lengths"].index_select(0, order),
"ntokens": task_target["ntokens"],
}
out["multitask"][task_name]["net_input"] = {
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
0, order
),
}
return out
class SpeechToSpeechDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
# optional columns
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_LANG = ""
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
data_cfg: S2SDataConfig,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
multitask: Optional[Dict] = None,
) -> SpeechToSpeechDataset:
audio_root = Path(data_cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
src_audio_paths = [
(audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
]
tgt_audio_paths = [
s[cls.KEY_TGT_AUDIO]
if target_is_code
else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
for s in samples
]
src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
)
ds = dataset_cls(
split=split_name,
is_train_split=is_train_split,
data_cfg=data_cfg,
src_audio_paths=src_audio_paths,
src_n_frames=src_n_frames,
tgt_audio_paths=tgt_audio_paths,
tgt_n_frames=tgt_n_frames,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
target_is_code=target_is_code,
tgt_dict=tgt_dict,
n_frames_per_step=n_frames_per_step,
)
if has_multitask:
for task_name, task_obj in multitask.items():
task_data = TextTargetMultitaskData(
task_obj.args, split_name, task_obj.target_dictionary
)
ds.add_multitask_dataset(task_name, task_data)
return ds
@classmethod
def from_tsv(
cls,
root: str,
data_cfg: S2SDataConfig,
splits: str,
is_train_split: bool,
epoch: int,
seed: int,
target_is_code: bool = False,
tgt_dict: Dictionary = None,
n_frames_per_step: int = 1,
multitask: Optional[Dict] = None,
) -> SpeechToSpeechDataset:
datasets = []
for split in splits.split(","):
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
ds = cls._from_list(
split_name=split,
is_train_split=is_train_split,
samples=samples,
data_cfg=data_cfg,
target_is_code=target_is_code,
tgt_dict=tgt_dict,
n_frames_per_step=n_frames_per_step,
multitask=multitask,
)
datasets.append(ds)
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,733 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
logger = logging.getLogger(__name__)
def _collate_frames(
frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
if is_audio_input:
out = frames[0].new_zeros((len(frames), max_len))
else:
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
def _is_int_or_np_int(n):
return isinstance(n, int) or (
isinstance(n, np.generic) and isinstance(n.item(), int)
)
@dataclass
class SpeechToTextDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
class SpeechToTextDataset(FairseqDataset):
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
append_eos=True,
):
self.split, self.is_train_split = split, is_train_split
self.cfg = cfg
self.audio_paths, self.n_frames = audio_paths, n_frames
self.n_samples = len(audio_paths)
assert len(n_frames) == self.n_samples > 0
assert src_texts is None or len(src_texts) == self.n_samples
assert tgt_texts is None or len(tgt_texts) == self.n_samples
assert speakers is None or len(speakers) == self.n_samples
assert src_langs is None or len(src_langs) == self.n_samples
assert tgt_langs is None or len(tgt_langs) == self.n_samples
assert ids is None or len(ids) == self.n_samples
assert (tgt_dict is None and tgt_texts is None) or (
tgt_dict is not None and tgt_texts is not None
)
self.src_texts, self.tgt_texts = src_texts, tgt_texts
self.src_langs, self.tgt_langs = src_langs, tgt_langs
self.speakers = speakers
self.tgt_dict = tgt_dict
self.check_tgt_lang_tag()
self.ids = ids
self.shuffle = cfg.shuffle if is_train_split else False
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
self.cfg.get_feature_transforms(split, is_train_split)
)
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
self.cfg.get_waveform_transforms(split, is_train_split)
)
# TODO: add these to data_cfg.py
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
self.cfg.get_dataset_transforms(split, is_train_split)
)
# check proper usage of transforms
if self.feature_transforms and self.cfg.use_audio_input:
logger.warning(
"Feature transforms will not be applied. To use feature transforms, "
"set use_audio_input as False in config."
)
self.pre_tokenizer = pre_tokenizer
self.bpe_tokenizer = bpe_tokenizer
self.n_frames_per_step = n_frames_per_step
self.speaker_to_id = speaker_to_id
self.tgt_lens = self.get_tgt_lens_and_check_oov()
self.append_eos = append_eos
logger.info(self.__repr__())
def get_tgt_lens_and_check_oov(self):
if self.tgt_texts is None:
return [0 for _ in range(self.n_samples)]
tgt_lens = []
n_tokens, n_oov_tokens = 0, 0
for i in range(self.n_samples):
tokenized = self.get_tokenized_tgt_text(i).split(" ")
oov_tokens = [
t
for t in tokenized
if self.tgt_dict.index(t) == self.tgt_dict.unk_index
]
n_tokens += len(tokenized)
n_oov_tokens += len(oov_tokens)
tgt_lens.append(len(tokenized))
logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
return tgt_lens
def __repr__(self):
return (
self.__class__.__name__
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
f"n_frames_per_step={self.n_frames_per_step}, "
f"shuffle={self.shuffle}, "
f"feature_transforms={self.feature_transforms}, "
f"waveform_transforms={self.waveform_transforms}, "
f"dataset_transforms={self.dataset_transforms})"
)
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
@classmethod
def tokenize(cls, tokenizer, text: str):
return text if tokenizer is None else tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
if _is_int_or_np_int(index):
text = self.tgt_texts[index]
else:
text = " ".join([self.tgt_texts[i] for i in index])
text = self.tokenize(self.pre_tokenizer, text)
text = self.tokenize(self.bpe_tokenizer, text)
return text
def pack_frames(self, feature: torch.Tensor):
if self.n_frames_per_step == 1:
return feature
n_packed_frames = feature.shape[0] // self.n_frames_per_step
feature = feature[: self.n_frames_per_step * n_packed_frames]
return feature.reshape(n_packed_frames, -1)
@classmethod
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
assert lang_tag_idx != dictionary.unk()
return lang_tag_idx
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
"""
Gives source audio for given index with any relevant transforms
applied. For ConcatAug, source audios for given indices are
concatenated in given order.
Args:
index (int or List[int]): index—or in the case of ConcatAug,
indices—to pull the source audio for
Returns:
source audios concatenated for given indices with
relevant transforms appplied
"""
if _is_int_or_np_int(index):
source = get_features_or_waveform(
self.audio_paths[index],
need_waveform=self.cfg.use_audio_input,
use_sample_rate=self.cfg.use_sample_rate,
waveform_transforms=self.waveform_transforms,
)
else:
source = np.concatenate(
[
get_features_or_waveform(
self.audio_paths[i],
need_waveform=self.cfg.use_audio_input,
use_sample_rate=self.cfg.use_sample_rate,
waveform_transforms=self.waveform_transforms,
)
for i in index
]
)
if self.cfg.use_audio_input:
source = torch.from_numpy(source).float()
if self.cfg.standardize_audio:
with torch.no_grad():
source = F.layer_norm(source, source.shape)
else:
if self.feature_transforms is not None:
source = self.feature_transforms(source)
source = torch.from_numpy(source).float()
return source
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
if has_concat:
concat = self.dataset_transforms.get_transform(ConcatAugment)
indices = concat.find_indices(index, self.n_frames, self.n_samples)
source = self._get_source_audio(indices if has_concat else index)
source = self.pack_frames(source)
target = None
if self.tgt_texts is not None:
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=self.append_eos
).long()
if self.cfg.prepend_tgt_lang_tag:
lang_tag_idx = self.get_lang_tag_idx(
self.tgt_langs[index], self.tgt_dict
)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
if self.cfg.prepend_bos_and_append_tgt_lang_tag:
bos = torch.LongTensor([self.tgt_dict.bos()])
lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
assert lang_tag_idx != self.tgt_dict.unk()
lang_tag_idx = torch.LongTensor([lang_tag_idx])
target = torch.cat((bos, target, lang_tag_idx), 0)
speaker_id = None
if self.speaker_to_id is not None:
speaker_id = self.speaker_to_id[self.speakers[index]]
return SpeechToTextDatasetItem(
index=index, source=source, target=target, speaker_id=speaker_id
)
def __len__(self):
return self.n_samples
def collater(
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
sources = [x.source for x in samples]
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
if has_NOAug and self.cfg.use_audio_input:
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
sources = NOAug(sources)
frames = _collate_frames(sources, self.cfg.use_audio_input)
# sort samples by descending number of frames
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
target = target.index_select(0, order)
target_lengths = torch.tensor(
[x.target.size(0) for x in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[x.target for x in samples],
self.tgt_dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(x.target.size(0) for x in samples)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
net_input = {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
}
out = {
"id": indices,
"net_input": net_input,
"speaker": speaker,
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
if return_order:
out["order"] = order
return out
def num_tokens(self, index):
return self.n_frames[index]
def size(self, index):
return self.n_frames[index], self.tgt_lens[index]
@property
def sizes(self):
return np.array(self.n_frames)
@property
def can_reuse_epoch_itr_across_epochs(self):
return True
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
# first by descending order of # of frames then by original/random order
order.append([-n for n in self.n_frames])
return np.lexsort(order)
def prefetch(self, indices):
raise False
class TextTargetMultitaskData(object):
# mandatory columns
KEY_ID, KEY_TEXT = "id", "tgt_text"
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(self, args, split, tgt_dict):
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
self.dict = tgt_dict
self.append_eos = args.decoder_type != "ctc"
self.pre_tokenizer = self.build_tokenizer(args)
self.bpe_tokenizer = self.build_bpe(args)
self.prepend_bos_and_append_tgt_lang_tag = (
args.prepend_bos_and_append_tgt_lang_tag
)
self.eos_token = args.eos_token
self.lang_tag_mapping = args.get_lang_tag_mapping
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
@classmethod
def tokenize(cls, tokenizer, text: str):
return text if tokenizer is None else tokenizer.encode(text)
def get_tokenized_tgt_text(self, index: int):
text = self.tokenize(self.pre_tokenizer, self.data[index])
text = self.tokenize(self.bpe_tokenizer, text)
return text
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
lang_tag_idx = dictionary.index(lang_tag)
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
return lang_tag_idx
def build_tokenizer(self, args):
pre_tokenizer = args.config.get("pre_tokenizer")
if pre_tokenizer is not None:
logger.info(f"pre-tokenizer: {pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**pre_tokenizer))
else:
return None
def build_bpe(self, args):
bpe_tokenizer = args.config.get("bpe_tokenizer")
if bpe_tokenizer is not None:
logger.info(f"tokenizer: {bpe_tokenizer}")
return encoders.build_bpe(Namespace(**bpe_tokenizer))
else:
return None
def get(self, sample_id, tgt_lang=None):
if sample_id in self.data:
tokenized = self.get_tokenized_tgt_text(sample_id)
target = self.dict.encode_line(
tokenized,
add_if_not_exist=False,
append_eos=self.append_eos,
)
if self.prepend_bos_and_append_tgt_lang_tag:
bos = torch.LongTensor([self.dict.bos()])
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
assert lang_tag_idx != self.dict.unk()
lang_tag_idx = torch.LongTensor([lang_tag_idx])
target = torch.cat((bos, target, lang_tag_idx), 0)
return target
else:
logger.warning(f"no target for {sample_id}")
return torch.IntTensor([])
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
).long()
prev_out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
eos_idx=None,
left_pad=False,
move_eos_to_beginning=True,
).long()
target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
ntokens = sum(t.size(0) for t in samples)
output = {
"prev_output_tokens": prev_out,
"target": out,
"target_lengths": target_lengths,
"ntokens": ntokens,
}
return output
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.multitask_data = {}
def add_multitask_dataset(self, task_name, task_data):
self.multitask_data[task_name] = task_data
def __getitem__(
self, index: int
) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
s2t_data = super().__getitem__(index)
multitask_target = {}
sample_id = self.ids[index]
tgt_lang = self.tgt_langs[index]
for task_name, task_dataset in self.multitask_data.items():
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
return s2t_data, multitask_target
def collater(
self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
) -> Dict:
if len(samples) == 0:
return {}
out = super().collater([s for s, _ in samples], return_order=True)
order = out["order"]
del out["order"]
for task_name, task_dataset in self.multitask_data.items():
if "multitask" not in out:
out["multitask"] = {}
d = [s[task_name] for _, s in samples]
task_target = task_dataset.collater(d)
out["multitask"][task_name] = {
"target": task_target["target"].index_select(0, order),
"target_lengths": task_target["target_lengths"].index_select(0, order),
"ntokens": task_target["ntokens"],
}
out["multitask"][task_name]["net_input"] = {
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
0, order
),
}
return out
class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
KEY_TGT_TEXT = "tgt_text"
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
# default values
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
)
ds = dataset_cls(
split=split_name,
is_train_split=is_train_split,
cfg=cfg,
audio_paths=audio_paths,
n_frames=n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
if has_multitask:
for task_name, task_obj in multitask.items():
task_data = TextTargetMultitaskData(
task_obj.args, split_name, task_obj.target_dictionary
)
ds.add_multitask_dataset(task_name, task_data)
return ds
@classmethod
def get_size_ratios(
cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
) -> List[float]:
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
id_to_lp, lp_to_sz = {}, defaultdict(int)
for ds in datasets:
lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
assert len(lang_pairs) == 1
lang_pair = list(lang_pairs)[0]
id_to_lp[ds.split] = lang_pair
lp_to_sz[lang_pair] += sum(ds.n_frames)
sz_sum = sum(v for v in lp_to_sz.values())
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
prob_sum = sum(v for v in lp_to_tgt_prob.values())
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
lp_to_sz_ratio = {
k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
}
size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
p_formatted = {
k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
}
logger.info(f"sampling probability balancing: {p_formatted}")
sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
logger.info(f"balanced sampling size ratio: {sr_formatted}")
return size_ratio
@classmethod
def _load_samples_from_tsv(cls, root: str, split: str):
tsv_path = Path(root) / f"{split}.tsv"
if not tsv_path.is_file():
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f,
delimiter="\t",
quotechar=None,
doublequote=False,
lineterminator="\n",
quoting=csv.QUOTE_NONE,
)
samples = [dict(e) for e in reader]
if len(samples) == 0:
raise ValueError(f"Empty manifest: {tsv_path}")
return samples
@classmethod
def _from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
split: str,
tgt_dict,
is_train_split: bool,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
samples = cls._load_samples_from_tsv(root, split)
return cls._from_list(
split,
is_train_split,
samples,
cfg,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask,
)
@classmethod
def from_tsv(
cls,
root: str,
cfg: S2TDataConfig,
splits: str,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
n_frames_per_step: int = 1,
speaker_to_id=None,
multitask: Optional[Dict] = None,
) -> SpeechToTextDataset:
datasets = [
cls._from_tsv(
root=root,
cfg=cfg,
split=split,
tgt_dict=tgt_dict,
is_train_split=is_train_split,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
multitask=multitask,
)
for split in splits.split(",")
]
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for r, d in zip(size_ratios, datasets)
]
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional
import torch
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
)
logger = logging.getLogger(__name__)
class S2TJointDataConfig(S2TDataConfig):
"""Wrapper class for data config YAML"""
@property
def src_vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get("src_vocab_filename", "src_dict.txt")
@property
def src_pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_pre_tokenizer", {"tokenizer": None})
@property
def src_bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply on source text after pre-tokenization.
Returning a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get("src_bpe_tokenizer", {"bpe": None})
@property
def prepend_tgt_lang_tag_no_change(self) -> bool:
"""Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
to-many multilingual setting). No change needed during inference.
This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
"""
value = self.config.get("prepend_tgt_lang_tag_no_change", None)
if value is None:
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
return value
@property
def sampling_text_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling. (text
input only) (alpha = 1 for no resampling)"""
return self.config.get("sampling_text_alpha", 1.0)
class SpeechToTextJointDatasetItem(NamedTuple):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
src_txt_tokens: Optional[torch.Tensor] = None
tgt_lang_tag: Optional[int] = None
src_lang_tag: Optional[int] = None
tgt_alignment: Optional[torch.Tensor] = None
# use_src_lang_id:
# 0: don't use src_lang_id
# 1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TJointDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
src_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
append_eos: Optional[bool] = True,
alignment: Optional[List[str]] = None,
use_src_lang_id: Optional[int] = 0,
):
super().__init__(
split,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
append_eos=append_eos,
)
self.src_dict = src_dict
self.src_pre_tokenizer = src_pre_tokenizer
self.src_bpe_tokenizer = src_bpe_tokenizer
self.alignment = None
self.use_src_lang_id = use_src_lang_id
if alignment is not None:
self.alignment = [
[float(s) for s in sample.split()] for sample in alignment
]
def get_tokenized_src_text(self, index: int):
text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
text = self.tokenize(self.src_bpe_tokenizer, text)
return text
def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
s2t_dataset_item = super().__getitem__(index)
src_tokens = None
src_lang_tag = None
if self.src_texts is not None and self.src_dict is not None:
src_tokens = self.get_tokenized_src_text(index)
src_tokens = self.src_dict.encode_line(
src_tokens, add_if_not_exist=False, append_eos=True
).long()
if self.use_src_lang_id > 0:
src_lang_tag = self.get_lang_tag_idx(
self.src_langs[index], self.src_dict
)
tgt_lang_tag = None
if self.cfg.prepend_tgt_lang_tag_no_change:
# prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
ali = None
if self.alignment is not None:
ali = torch.Tensor(self.alignment[index]).float()
return SpeechToTextJointDatasetItem(
index=index,
source=s2t_dataset_item.source,
target=s2t_dataset_item.target,
src_txt_tokens=src_tokens,
tgt_lang_tag=tgt_lang_tag,
src_lang_tag=src_lang_tag,
tgt_alignment=ali,
)
def __len__(self):
return self.n_samples
def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
s2t_out = super().collater(samples, return_order=True)
if s2t_out == {}:
return s2t_out
net_input, order = s2t_out["net_input"], s2t_out["order"]
if self.src_texts is not None and self.src_dict is not None:
src_txt_tokens = fairseq_data_utils.collate_tokens(
[x.src_txt_tokens for x in samples],
self.src_dict.pad(),
self.src_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
src_txt_lengths = torch.tensor(
[x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
)
if self.use_src_lang_id > 0:
src_lang_idxs = torch.tensor(
[s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
)
if self.use_src_lang_id == 1: # replace eos with lang_id
eos_idx = src_txt_lengths - 1
src_txt_tokens.scatter_(
1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
)
else:
raise NotImplementedError("Implementation is required")
src_txt_tokens = src_txt_tokens.index_select(0, order)
src_txt_lengths = src_txt_lengths.index_select(0, order)
net_input["src_txt_tokens"] = src_txt_tokens
net_input["src_txt_lengths"] = src_txt_lengths
net_input["alignment"] = None
if self.alignment is not None:
max_len = max([s.tgt_alignment.size(0) for s in samples])
alignment = torch.ones(len(samples), max_len).float()
for i, s in enumerate(samples):
cur_len = s.tgt_alignment.size(0)
alignment[i][:cur_len].copy_(s.tgt_alignment)
net_input["alignment"] = alignment.index_select(0, order)
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
for i in range(len(samples)):
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
out = {
"id": s2t_out["id"],
"net_input": net_input,
"target": s2t_out["target"],
"target_lengths": s2t_out["target_lengths"],
"ntokens": s2t_out["ntokens"],
"nsentences": len(samples),
}
return out
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
KEY_ALIGN = "align"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TJointDataConfig,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos,
use_src_lang_id,
) -> SpeechToTextJointDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_alignment = None
if cls.KEY_ALIGN in samples[0].keys():
tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
return SpeechToTextJointDataset(
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=src_pre_tokenizer,
src_bpe_tokenizer=src_bpe_tokenizer,
append_eos=append_eos,
alignment=tgt_alignment,
use_src_lang_id=use_src_lang_id,
)
@classmethod
def _from_tsv(
cls,
root: str,
cfg: S2TJointDataConfig,
split: str,
tgt_dict,
src_dict,
is_train_split: bool,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos: bool,
use_src_lang_id: int,
) -> SpeechToTextJointDataset:
samples = cls._load_samples_from_tsv(root, split)
return cls._from_list(
split,
is_train_split,
samples,
cfg,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos,
use_src_lang_id,
)
@classmethod
def from_tsv(
cls,
root: str,
cfg: S2TJointDataConfig,
splits: str,
tgt_dict,
src_dict,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
is_train_split: bool,
epoch: int,
seed: int,
append_eos: Optional[bool] = True,
use_src_lang_id: Optional[int] = 0,
) -> SpeechToTextJointDataset:
datasets = [
cls._from_tsv(
root,
cfg,
split,
tgt_dict,
src_dict,
is_train_split,
pre_tokenizer,
bpe_tokenizer,
src_pre_tokenizer,
src_bpe_tokenizer,
append_eos=append_eos,
use_src_lang_id=use_src_lang_id,
)
for split in splits.split(",")
]
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
# temperature-based sampling
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
datasets = [
ResamplingDataset(
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
)
for r, d in zip(size_ratios, datasets)
]
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

View File

@@ -0,0 +1,250 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
_collate_frames,
)
@dataclass
class TextToSpeechDatasetItem(object):
index: int
source: torch.Tensor
target: Optional[torch.Tensor] = None
speaker_id: Optional[int] = None
duration: Optional[torch.Tensor] = None
pitch: Optional[torch.Tensor] = None
energy: Optional[torch.Tensor] = None
class TextToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None,
):
super(TextToSpeechDataset, self).__init__(
split,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.durations = durations
self.pitches = pitches
self.energies = energies
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
s2t_item = super().__getitem__(index)
duration, pitch, energy = None, None, None
if self.durations is not None:
duration = torch.tensor(
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
)
if self.pitches is not None:
pitch = get_features_or_waveform(self.pitches[index])
pitch = torch.from_numpy(
np.concatenate((pitch, [0])) # pad 0 for EOS
).float()
if self.energies is not None:
energy = get_features_or_waveform(self.energies[index])
energy = torch.from_numpy(
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToSpeechDatasetItem(
index=index,
source=s2t_item.source,
target=s2t_item.target,
speaker_id=s2t_item.speaker_id,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
if len(samples) == 0:
return {}
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order)
target_lengths = torch.tensor(
[s.source.shape[0] for s in samples], dtype=torch.long
).index_select(0, order)
src_tokens = fairseq_data_utils.collate_tokens(
[s.target for s in samples],
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
).index_select(0, order)
speaker = None
if self.speaker_to_id is not None:
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
bsz, _, d = feat.size()
prev_output_tokens = torch.cat(
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
)
durations, pitches, energies = None, None, None
if self.durations is not None:
durations = fairseq_data_utils.collate_tokens(
[s.duration for s in samples], 0
).index_select(0, order)
assert src_tokens.shape[1] == durations.shape[1]
if self.pitches is not None:
pitches = _collate_frames([s.pitch for s in samples], True)
pitches = pitches.index_select(0, order)
assert src_tokens.shape[1] == pitches.shape[1]
if self.energies is not None:
energies = _collate_frames([s.energy for s in samples], True)
energies = energies.index_select(0, order)
assert src_tokens.shape[1] == energies.shape[1]
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
return {
"id": id_,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"prev_output_tokens": prev_output_tokens,
},
"speaker": speaker,
"target": feat,
"durations": durations,
"pitches": pitches,
"energies": energies,
"target_lengths": target_lengths,
"ntokens": sum(target_lengths).item(),
"nsentences": len(samples),
"src_texts": src_texts,
}
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
KEY_DURATION = "duration"
KEY_PITCH = "pitch"
KEY_ENERGY = "energy"
@classmethod
def _from_list(
cls,
split_name: str,
is_train_split,
samples: List[Dict],
cfg: S2TDataConfig,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
multitask=None,
) -> TextToSpeechDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix() for ee in energies
]
energies = None if any(ee is None for ee in energies) else energies
return TextToSpeechDataset(
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
)

View File

@@ -0,0 +1,48 @@
import os
from fairseq.data.audio import (
AudioTransform,
CompositeAudioTransform,
import_transforms,
register_audio_transform,
)
class AudioWaveformTransform(AudioTransform):
pass
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
def get_audio_waveform_transform(name):
return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
def register_audio_waveform_transform(name):
return register_audio_transform(
name,
AudioWaveformTransform,
AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
)
import_transforms(os.path.dirname(__file__), "waveform")
class CompositeAudioWaveformTransform(CompositeAudioTransform):
@classmethod
def from_config_dict(cls, config=None):
return super()._from_config_dict(
cls,
"waveform",
get_audio_waveform_transform,
CompositeAudioWaveformTransform,
config,
)
def __call__(self, x, sample_rate):
for t in self.transforms:
x, sample_rate = t(x, sample_rate)
return x, sample_rate

View File

@@ -0,0 +1,201 @@
from pathlib import Path
import numpy as np
from math import ceil
from fairseq.data.audio import rand_uniform
from fairseq.data.audio.waveform_transforms import (
AudioWaveformTransform,
register_audio_waveform_transform,
)
SNR_MIN = 5.0
SNR_MAX = 15.0
RATE = 0.25
NOISE_RATE = 1.0
NOISE_LEN_MEAN = 0.2
NOISE_LEN_STD = 0.05
class NoiseAugmentTransform(AudioWaveformTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return cls(
_config.get("samples_path", None),
_config.get("snr_min", SNR_MIN),
_config.get("snr_max", SNR_MAX),
_config.get("rate", RATE),
)
def __init__(
self,
samples_path: str,
snr_min: float = SNR_MIN,
snr_max: float = SNR_MAX,
rate: float = RATE,
):
# Sanity checks
assert (
samples_path
), "need to provide path to audio samples for noise augmentation"
assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
self.n_samples = len(self.paths)
assert self.n_samples > 0, f"no audio files found in {samples_path}"
self.snr_min = snr_min
self.snr_max = snr_max
self.rate = rate
def __repr__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
[
f"n_samples={self.n_samples}",
f"snr={self.snr_min}-{self.snr_max}dB",
f"rate={self.rate}",
]
)
+ ")"
)
def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
from fairseq.data.audio.audio_utils import get_waveform
path = self.paths[np.random.randint(0, self.n_samples)]
sample = get_waveform(
path, always_2d=always_2d, output_sample_rate=use_sample_rate
)[0]
# Check dimensions match, else silently skip adding noise to sample
# NOTE: SHOULD THIS QUIT WITH AN ERROR?
is_2d = len(goal_shape) == 2
if len(goal_shape) != sample.ndim or (
is_2d and goal_shape[0] != sample.shape[0]
):
return np.zeros(goal_shape)
# Cut/repeat sample to size
len_dim = len(goal_shape) - 1
n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
return (
repeated[:, start : start + goal_shape[len_dim]]
if is_2d
else repeated[start : start + goal_shape[len_dim]]
)
def _mix(self, source, noise, snr):
get_power = lambda x: np.mean(x**2)
if get_power(noise):
scl = np.sqrt(
get_power(source) / (np.power(10, snr / 10) * get_power(noise))
)
else:
scl = 0
return 1 * source + scl * noise
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
return self.pick_sample(goal_shape, always_2d, use_sample_rate)
def __call__(self, source, sample_rate):
if np.random.random() > self.rate:
return source, sample_rate
noise = self._get_noise(
source.shape, always_2d=True, use_sample_rate=sample_rate
)
return (
self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
sample_rate,
)
@register_audio_waveform_transform("musicaugment")
class MusicAugmentTransform(NoiseAugmentTransform):
pass
@register_audio_waveform_transform("backgroundnoiseaugment")
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
pass
@register_audio_waveform_transform("babbleaugment")
class BabbleAugmentTransform(NoiseAugmentTransform):
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
for i in range(np.random.randint(3, 8)):
speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
if i == 0:
agg_noise = speech
else: # SNR scaled by i (how many noise signals already in agg_noise)
agg_noise = self._mix(agg_noise, speech, i)
return agg_noise
@register_audio_waveform_transform("sporadicnoiseaugment")
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return cls(
_config.get("samples_path", None),
_config.get("snr_min", SNR_MIN),
_config.get("snr_max", SNR_MAX),
_config.get("rate", RATE),
_config.get("noise_rate", NOISE_RATE),
_config.get("noise_len_mean", NOISE_LEN_MEAN),
_config.get("noise_len_std", NOISE_LEN_STD),
)
def __init__(
self,
samples_path: str,
snr_min: float = SNR_MIN,
snr_max: float = SNR_MAX,
rate: float = RATE,
noise_rate: float = NOISE_RATE, # noises per second
noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
noise_len_std: float = NOISE_LEN_STD,
):
super().__init__(samples_path, snr_min, snr_max, rate)
self.noise_rate = noise_rate
self.noise_len_mean = noise_len_mean
self.noise_len_std = noise_len_std
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
agg_noise = np.zeros(goal_shape)
len_dim = len(goal_shape) - 1
is_2d = len(goal_shape) == 2
n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
start_pointers = [
round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
]
for start_pointer in start_pointers:
noise_shape = list(goal_shape)
len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
end_pointer = start_pointer + noise_shape[len_dim]
if end_pointer >= goal_shape[len_dim]:
continue
noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
if is_2d:
agg_noise[:, start_pointer:end_pointer] = (
agg_noise[:, start_pointer:end_pointer] + noise
)
else:
agg_noise[start_pointer:end_pointer] = (
agg_noise[start_pointer:end_pointer] + noise
)
return agg_noise