mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 19:31:20 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
93
modules/voice_conversion/fairseq/data/audio/__init__.py
Normal file
93
modules/voice_conversion/fairseq/data/audio/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
import importlib
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AudioTransform(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config_dict(cls, config: Optional[Dict] = None):
|
||||
pass
|
||||
|
||||
|
||||
class CompositeAudioTransform(AudioTransform):
|
||||
def _from_config_dict(
|
||||
cls,
|
||||
transform_type,
|
||||
get_audio_transform,
|
||||
composite_cls,
|
||||
config=None,
|
||||
return_empty=False,
|
||||
):
|
||||
_config = {} if config is None else config
|
||||
_transforms = _config.get(f"{transform_type}_transforms")
|
||||
|
||||
if _transforms is None:
|
||||
if return_empty:
|
||||
_transforms = []
|
||||
else:
|
||||
return None
|
||||
|
||||
transforms = [
|
||||
get_audio_transform(_t).from_config_dict(_config.get(_t))
|
||||
for _t in _transforms
|
||||
]
|
||||
return composite_cls(transforms)
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = [t for t in transforms if t is not None]
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.transforms:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
format_string = (
|
||||
[self.__class__.__name__ + "("]
|
||||
+ [f" {t.__repr__()}" for t in self.transforms]
|
||||
+ [")"]
|
||||
)
|
||||
return "\n".join(format_string)
|
||||
|
||||
|
||||
def register_audio_transform(name, cls_type, registry, class_names):
|
||||
def register_audio_transform_cls(cls):
|
||||
if name in registry:
|
||||
raise ValueError(f"Cannot register duplicate transform ({name})")
|
||||
if not issubclass(cls, cls_type):
|
||||
raise ValueError(
|
||||
f"Transform ({name}: {cls.__name__}) must extend "
|
||||
f"{cls_type.__name__}"
|
||||
)
|
||||
if cls.__name__ in class_names:
|
||||
raise ValueError(
|
||||
f"Cannot register audio transform with duplicate "
|
||||
f"class name ({cls.__name__})"
|
||||
)
|
||||
registry[name] = cls
|
||||
class_names.add(cls.__name__)
|
||||
return cls
|
||||
|
||||
return register_audio_transform_cls
|
||||
|
||||
|
||||
def import_transforms(transforms_dir, transform_type):
|
||||
for file in os.listdir(transforms_dir):
|
||||
path = os.path.join(transforms_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
importlib.import_module(
|
||||
f"fairseq.data.audio.{transform_type}_transforms." + name
|
||||
)
|
||||
|
||||
|
||||
# Utility fn for uniform numbers in transforms
|
||||
def rand_uniform(a, b):
|
||||
return np.random.uniform() * (b - a) + a
|
||||
389
modules/voice_conversion/fairseq/data/audio/audio_utils.py
Normal file
389
modules/voice_conversion/fairseq/data/audio/audio_utils.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import mmap
|
||||
from pathlib import Path
|
||||
import io
|
||||
from typing import BinaryIO, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
||||
|
||||
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
|
||||
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
|
||||
|
||||
|
||||
def convert_waveform(
|
||||
waveform: Union[np.ndarray, torch.Tensor],
|
||||
sample_rate: int,
|
||||
normalize_volume: bool = False,
|
||||
to_mono: bool = False,
|
||||
to_sample_rate: Optional[int] = None,
|
||||
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
|
||||
"""convert a waveform:
|
||||
- to a target sample rate
|
||||
- from multi-channel to mono channel
|
||||
- volume normalization
|
||||
|
||||
Args:
|
||||
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
|
||||
(channels x length)
|
||||
sample_rate (int): original sample rate
|
||||
normalize_volume (bool): perform volume normalization
|
||||
to_mono (bool): convert to mono channel if having multiple channels
|
||||
to_sample_rate (Optional[int]): target sample rate
|
||||
Returns:
|
||||
waveform (numpy.ndarray): converted 2D waveform (channels x length)
|
||||
sample_rate (float): target sample rate
|
||||
"""
|
||||
try:
|
||||
import torchaudio.sox_effects as ta_sox
|
||||
except ImportError:
|
||||
raise ImportError("Please install torchaudio: pip install torchaudio")
|
||||
|
||||
effects = []
|
||||
if normalize_volume:
|
||||
effects.append(["gain", "-n"])
|
||||
if to_sample_rate is not None and to_sample_rate != sample_rate:
|
||||
effects.append(["rate", f"{to_sample_rate}"])
|
||||
if to_mono and waveform.shape[0] > 1:
|
||||
effects.append(["channels", "1"])
|
||||
if len(effects) > 0:
|
||||
is_np_input = isinstance(waveform, np.ndarray)
|
||||
_waveform = torch.from_numpy(waveform) if is_np_input else waveform
|
||||
converted, converted_sample_rate = ta_sox.apply_effects_tensor(
|
||||
_waveform, sample_rate, effects
|
||||
)
|
||||
if is_np_input:
|
||||
converted = converted.numpy()
|
||||
return converted, converted_sample_rate
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def get_waveform(
|
||||
path_or_fp: Union[str, BinaryIO],
|
||||
normalization: bool = True,
|
||||
mono: bool = True,
|
||||
frames: int = -1,
|
||||
start: int = 0,
|
||||
always_2d: bool = True,
|
||||
output_sample_rate: Optional[int] = None,
|
||||
normalize_volume: bool = False,
|
||||
waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
|
||||
|
||||
Args:
|
||||
path_or_fp (str or BinaryIO): the path or file-like object
|
||||
normalization (bool): normalize values to [-1, 1] (Default: True)
|
||||
mono (bool): convert multi-channel audio to mono-channel one
|
||||
frames (int): the number of frames to read. (-1 for reading all)
|
||||
start (int): Where to start reading. A negative value counts from the end.
|
||||
always_2d (bool): always return 2D array even for mono-channel audios
|
||||
output_sample_rate (Optional[int]): output sample rate
|
||||
normalize_volume (bool): normalize volume
|
||||
Returns:
|
||||
waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
|
||||
sample_rate (float): sample rate
|
||||
"""
|
||||
if isinstance(path_or_fp, str):
|
||||
ext = Path(path_or_fp).suffix
|
||||
if ext not in SF_AUDIO_FILE_EXTENSIONS:
|
||||
raise ValueError(f"Unsupported audio format: {ext}")
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
except ImportError:
|
||||
raise ImportError("Please install soundfile: pip install soundfile")
|
||||
|
||||
waveform, sample_rate = sf.read(
|
||||
path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
|
||||
)
|
||||
waveform = waveform.T # T x C -> C x T
|
||||
waveform, sample_rate = convert_waveform(
|
||||
waveform,
|
||||
sample_rate,
|
||||
normalize_volume=normalize_volume,
|
||||
to_mono=mono,
|
||||
to_sample_rate=output_sample_rate,
|
||||
)
|
||||
|
||||
if not normalization:
|
||||
waveform *= 2**15 # denormalized to 16-bit signed integers
|
||||
|
||||
if waveform_transforms is not None:
|
||||
waveform, sample_rate = waveform_transforms(waveform, sample_rate)
|
||||
|
||||
if not always_2d:
|
||||
waveform = waveform.squeeze(axis=0)
|
||||
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def get_features_from_npy_or_audio(path, waveform_transforms=None):
|
||||
ext = Path(path).suffix
|
||||
if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
|
||||
raise ValueError(f'Unsupported file format for "{path}"')
|
||||
return (
|
||||
np.load(path)
|
||||
if ext == ".npy"
|
||||
else get_fbank(path, waveform_transforms=waveform_transforms)
|
||||
)
|
||||
|
||||
|
||||
def get_features_or_waveform_from_stored_zip(
|
||||
path,
|
||||
byte_offset,
|
||||
byte_size,
|
||||
need_waveform=False,
|
||||
use_sample_rate=None,
|
||||
waveform_transforms=None,
|
||||
):
|
||||
assert path.endswith(".zip")
|
||||
data = read_from_stored_zip(path, byte_offset, byte_size)
|
||||
f = io.BytesIO(data)
|
||||
if is_npy_data(data):
|
||||
features_or_waveform = np.load(f)
|
||||
elif is_sf_audio_data(data):
|
||||
features_or_waveform = (
|
||||
get_waveform(
|
||||
f,
|
||||
always_2d=False,
|
||||
output_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)[0]
|
||||
if need_waveform
|
||||
else get_fbank(f, waveform_transforms=waveform_transforms)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown file format for "{path}"')
|
||||
return features_or_waveform
|
||||
|
||||
|
||||
def get_features_or_waveform(
|
||||
path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
|
||||
):
|
||||
"""Get speech features from .npy file or waveform from .wav/.flac file.
|
||||
The file may be inside an uncompressed ZIP file and is accessed via byte
|
||||
offset and length.
|
||||
|
||||
Args:
|
||||
path (str): File path in the format of "<.npy/.wav/.flac path>" or
|
||||
"<zip path>:<byte offset>:<byte length>".
|
||||
need_waveform (bool): return waveform instead of features.
|
||||
use_sample_rate (int): change sample rate for the input wave file
|
||||
|
||||
Returns:
|
||||
features_or_waveform (numpy.ndarray): speech features or waveform.
|
||||
"""
|
||||
_path, slice_ptr = parse_path(path)
|
||||
if len(slice_ptr) == 0:
|
||||
if need_waveform:
|
||||
return get_waveform(
|
||||
_path,
|
||||
always_2d=False,
|
||||
output_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)[0]
|
||||
return get_features_from_npy_or_audio(
|
||||
_path, waveform_transforms=waveform_transforms
|
||||
)
|
||||
elif len(slice_ptr) == 2:
|
||||
features_or_waveform = get_features_or_waveform_from_stored_zip(
|
||||
_path,
|
||||
slice_ptr[0],
|
||||
slice_ptr[1],
|
||||
need_waveform=need_waveform,
|
||||
use_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
|
||||
return features_or_waveform
|
||||
|
||||
|
||||
def _get_kaldi_fbank(
|
||||
waveform: np.ndarray, sample_rate: int, n_bins=80
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Get mel-filter bank features via PyKaldi."""
|
||||
try:
|
||||
from kaldi.feat.fbank import Fbank, FbankOptions
|
||||
from kaldi.feat.mel import MelBanksOptions
|
||||
from kaldi.feat.window import FrameExtractionOptions
|
||||
from kaldi.matrix import Vector
|
||||
|
||||
mel_opts = MelBanksOptions()
|
||||
mel_opts.num_bins = n_bins
|
||||
frame_opts = FrameExtractionOptions()
|
||||
frame_opts.samp_freq = sample_rate
|
||||
opts = FbankOptions()
|
||||
opts.mel_opts = mel_opts
|
||||
opts.frame_opts = frame_opts
|
||||
fbank = Fbank(opts=opts)
|
||||
features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
|
||||
return features
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _get_torchaudio_fbank(
|
||||
waveform: np.ndarray, sample_rate, n_bins=80
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Get mel-filter bank features via TorchAudio."""
|
||||
try:
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
waveform = torch.from_numpy(waveform)
|
||||
features = ta_kaldi.fbank(
|
||||
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
|
||||
)
|
||||
return features.numpy()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def get_fbank(
|
||||
path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
|
||||
) -> np.ndarray:
|
||||
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
|
||||
(faster CPP implementation) to TorchAudio (Python implementation). Note that
|
||||
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
|
||||
waveform should not be normalized."""
|
||||
waveform, sample_rate = get_waveform(
|
||||
path_or_fp, normalization=False, waveform_transforms=waveform_transforms
|
||||
)
|
||||
|
||||
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
|
||||
if features is None:
|
||||
features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
|
||||
if features is None:
|
||||
raise ImportError(
|
||||
"Please install pyKaldi or torchaudio to enable "
|
||||
"online filterbank feature extraction"
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def is_npy_data(data: bytes) -> bool:
|
||||
return data[0] == 147 and data[1] == 78
|
||||
|
||||
|
||||
def is_sf_audio_data(data: bytes) -> bool:
|
||||
is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
|
||||
is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
|
||||
is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
|
||||
return is_wav or is_flac or is_ogg
|
||||
|
||||
|
||||
def mmap_read(path: str, offset: int, length: int) -> bytes:
|
||||
with open(path, "rb") as f:
|
||||
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
|
||||
data = mmap_o[offset : offset + length]
|
||||
return data
|
||||
|
||||
|
||||
def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
|
||||
return mmap_read(zip_path, offset, length)
|
||||
|
||||
|
||||
def parse_path(path: str) -> Tuple[str, List[int]]:
|
||||
"""Parse data path which is either a path to
|
||||
1. a .npy/.wav/.flac/.ogg file
|
||||
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
|
||||
|
||||
Args:
|
||||
path (str): the data path to parse
|
||||
|
||||
Returns:
|
||||
file_path (str): the file path
|
||||
slice_ptr (list of int): empty in case 1;
|
||||
byte offset and length for the slice in case 2
|
||||
"""
|
||||
|
||||
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
|
||||
_path, slice_ptr = path, []
|
||||
else:
|
||||
_path, *slice_ptr = path.split(":")
|
||||
if not Path(_path).is_file():
|
||||
raise FileNotFoundError(f"File not found: {_path}")
|
||||
assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
|
||||
slice_ptr = [int(i) for i in slice_ptr]
|
||||
return _path, slice_ptr
|
||||
|
||||
|
||||
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
|
||||
padding = n_fft - win_length
|
||||
assert padding >= 0
|
||||
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
|
||||
|
||||
|
||||
def get_fourier_basis(n_fft: int) -> torch.Tensor:
|
||||
basis = np.fft.fft(np.eye(n_fft))
|
||||
basis = np.vstack(
|
||||
[np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
|
||||
)
|
||||
return torch.from_numpy(basis).float()
|
||||
|
||||
|
||||
def get_mel_filters(
|
||||
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
raise ImportError("Please install librosa: pip install librosa")
|
||||
basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
|
||||
return torch.from_numpy(basis).float()
|
||||
|
||||
|
||||
class TTSSpectrogram(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft: int,
|
||||
win_length: int,
|
||||
hop_length: int,
|
||||
window_fn: callable = torch.hann_window,
|
||||
return_phase: bool = False,
|
||||
) -> None:
|
||||
super(TTSSpectrogram, self).__init__()
|
||||
self.n_fft = n_fft
|
||||
self.hop_length = hop_length
|
||||
self.return_phase = return_phase
|
||||
|
||||
basis = get_fourier_basis(n_fft).unsqueeze(1)
|
||||
basis *= get_window(window_fn, n_fft, win_length)
|
||||
self.register_buffer("basis", basis)
|
||||
|
||||
def forward(
|
||||
self, waveform: torch.Tensor
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
padding = (self.n_fft // 2, self.n_fft // 2)
|
||||
x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
|
||||
x = F.conv1d(x, self.basis, stride=self.hop_length)
|
||||
real_part = x[:, : self.n_fft // 2 + 1, :]
|
||||
imag_part = x[:, self.n_fft // 2 + 1 :, :]
|
||||
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
||||
if self.return_phase:
|
||||
phase = torch.atan2(imag_part, real_part)
|
||||
return magnitude, phase
|
||||
return magnitude
|
||||
|
||||
|
||||
class TTSMelScale(torch.nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
|
||||
) -> None:
|
||||
super(TTSMelScale, self).__init__()
|
||||
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
|
||||
self.register_buffer("basis", basis)
|
||||
|
||||
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
|
||||
return torch.matmul(self.basis, specgram)
|
||||
387
modules/voice_conversion/fairseq/data/audio/data_cfg.py
Normal file
387
modules/voice_conversion/fairseq/data/audio/data_cfg.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from argparse import Namespace
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_config_from_yaml(yaml_path: Path):
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
print("Please install PyYAML: pip install PyYAML")
|
||||
config = {}
|
||||
if yaml_path.is_file():
|
||||
try:
|
||||
with open(yaml_path) as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
|
||||
else:
|
||||
raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class S2TDataConfig(object):
|
||||
"""Wrapper class for data config YAML"""
|
||||
|
||||
def __init__(self, yaml_path: Path):
|
||||
self.config = get_config_from_yaml(yaml_path)
|
||||
self.root = yaml_path.parent
|
||||
|
||||
def _auto_convert_to_abs_path(self, x):
|
||||
if isinstance(x, str):
|
||||
if not Path(x).exists() and (self.root / x).exists():
|
||||
return (self.root / x).as_posix()
|
||||
elif isinstance(x, dict):
|
||||
return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
|
||||
return x
|
||||
|
||||
@property
|
||||
def vocab_filename(self):
|
||||
"""fairseq vocabulary file under data root"""
|
||||
return self.config.get("vocab_filename", "dict.txt")
|
||||
|
||||
@property
|
||||
def speaker_set_filename(self):
|
||||
"""speaker set file under data root"""
|
||||
return self.config.get("speaker_set_filename", None)
|
||||
|
||||
@property
|
||||
def shuffle(self) -> bool:
|
||||
"""Shuffle dataset samples before batching"""
|
||||
return self.config.get("shuffle", False)
|
||||
|
||||
@property
|
||||
def pre_tokenizer(self) -> Dict:
|
||||
"""Pre-tokenizer to apply before subword tokenization. Returning
|
||||
a dictionary with `tokenizer` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
|
||||
return self._auto_convert_to_abs_path(tokenizer)
|
||||
|
||||
@property
|
||||
def bpe_tokenizer(self) -> Dict:
|
||||
"""Subword tokenizer to apply after pre-tokenization. Returning
|
||||
a dictionary with `bpe` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
|
||||
return self._auto_convert_to_abs_path(tokenizer)
|
||||
|
||||
@property
|
||||
def prepend_tgt_lang_tag(self) -> bool:
|
||||
"""Prepend target lang ID token as the target BOS (e.g. for to-many
|
||||
multilingual setting). During inference, this requires `--prefix-size 1`
|
||||
to force BOS to be lang ID token."""
|
||||
return self.config.get("prepend_tgt_lang_tag", False)
|
||||
|
||||
@property
|
||||
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
|
||||
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
|
||||
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
|
||||
|
||||
@property
|
||||
def input_feat_per_channel(self):
|
||||
"""The dimension of input features (per audio channel)"""
|
||||
return self.config.get("input_feat_per_channel", 80)
|
||||
|
||||
@property
|
||||
def input_channels(self):
|
||||
"""The number of channels in the input audio"""
|
||||
return self.config.get("input_channels", 1)
|
||||
|
||||
@property
|
||||
def sample_rate(self):
|
||||
return self.config.get("sample_rate", 16_000)
|
||||
|
||||
@property
|
||||
def sampling_alpha(self):
|
||||
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
|
||||
(alpha = 1 for no resampling)"""
|
||||
return self.config.get("sampling_alpha", 1.0)
|
||||
|
||||
@property
|
||||
def use_audio_input(self):
|
||||
"""Needed by the dataset loader to see if the model requires
|
||||
raw audio as inputs."""
|
||||
return self.config.get("use_audio_input", False)
|
||||
|
||||
def standardize_audio(self) -> bool:
|
||||
return self.use_audio_input and self.config.get("standardize_audio", False)
|
||||
|
||||
@property
|
||||
def use_sample_rate(self):
|
||||
"""Needed by the dataset loader to see if the model requires
|
||||
raw audio with specific sample rate as inputs."""
|
||||
return self.config.get("use_sample_rate", 16000)
|
||||
|
||||
@property
|
||||
def audio_root(self):
|
||||
"""Audio paths in the manifest TSV can be relative and this provides
|
||||
the root path. Set this to empty string when using absolute paths."""
|
||||
return self.config.get("audio_root", "")
|
||||
|
||||
def get_transforms(self, transform_type, split, is_train):
|
||||
"""Split-specific feature transforms. Allowing train set
|
||||
wildcard `_train`, evaluation set wildcard `_eval` and general
|
||||
wildcard `*` for matching."""
|
||||
from copy import deepcopy
|
||||
|
||||
cfg = deepcopy(self.config)
|
||||
_cur = cfg.get(f"{transform_type}transforms", {})
|
||||
cur = _cur.get(split)
|
||||
cur = _cur.get("_train") if cur is None and is_train else cur
|
||||
cur = _cur.get("_eval") if cur is None and not is_train else cur
|
||||
cur = _cur.get("*") if cur is None else cur
|
||||
return cur
|
||||
|
||||
def get_feature_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
# TODO: deprecate transforms
|
||||
cur = self.get_transforms("", split, is_train)
|
||||
if cur is not None:
|
||||
logger.warning(
|
||||
"Auto converting transforms into feature_transforms, "
|
||||
"but transforms will be deprecated in the future. Please "
|
||||
"update this in the config."
|
||||
)
|
||||
ft_transforms = self.get_transforms("feature_", split, is_train)
|
||||
if ft_transforms:
|
||||
cur.extend(ft_transforms)
|
||||
else:
|
||||
cur = self.get_transforms("feature_", split, is_train)
|
||||
cfg["feature_transforms"] = cur
|
||||
return cfg
|
||||
|
||||
def get_waveform_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
|
||||
return cfg
|
||||
|
||||
def get_dataset_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
|
||||
return cfg
|
||||
|
||||
@property
|
||||
def global_cmvn_stats_npz(self) -> Optional[str]:
|
||||
path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
|
||||
return self._auto_convert_to_abs_path(path)
|
||||
|
||||
@property
|
||||
def vocoder(self) -> Dict[str, str]:
|
||||
vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
|
||||
return self._auto_convert_to_abs_path(vocoder)
|
||||
|
||||
@property
|
||||
def hub(self) -> Dict[str, str]:
|
||||
return self.config.get("hub", {})
|
||||
|
||||
|
||||
class S2SDataConfig(S2TDataConfig):
|
||||
"""Wrapper class for data config YAML"""
|
||||
|
||||
@property
|
||||
def vocab_filename(self):
|
||||
"""fairseq vocabulary file under data root"""
|
||||
return self.config.get("vocab_filename", None)
|
||||
|
||||
@property
|
||||
def pre_tokenizer(self) -> Dict:
|
||||
return None
|
||||
|
||||
@property
|
||||
def bpe_tokenizer(self) -> Dict:
|
||||
return None
|
||||
|
||||
@property
|
||||
def input_transformed_channels(self):
|
||||
"""The number of channels in the audio after feature transforms"""
|
||||
# TODO: move this into individual transforms
|
||||
# TODO: deprecate transforms
|
||||
_cur = self.config.get("transforms", {})
|
||||
ft_transforms = self.config.get("feature_transforms", {})
|
||||
if _cur and ft_transforms:
|
||||
_cur.update(ft_transforms)
|
||||
else:
|
||||
_cur = self.config.get("feature_transforms", {})
|
||||
cur = _cur.get("_train", [])
|
||||
|
||||
_channels = self.input_channels
|
||||
if "delta_deltas" in cur:
|
||||
_channels *= 3
|
||||
|
||||
return _channels
|
||||
|
||||
@property
|
||||
def output_sample_rate(self):
|
||||
"""The audio sample rate of output target speech"""
|
||||
return self.config.get("output_sample_rate", 22050)
|
||||
|
||||
@property
|
||||
def target_speaker_embed(self):
|
||||
"""Target speaker embedding file (one line per target audio sample)"""
|
||||
return self.config.get("target_speaker_embed", None)
|
||||
|
||||
@property
|
||||
def prepend_tgt_lang_tag_as_bos(self) -> bool:
|
||||
"""Prepend target lang ID token as the target BOS."""
|
||||
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
|
||||
|
||||
|
||||
class MultitaskConfig(object):
|
||||
"""Wrapper class for data config YAML"""
|
||||
|
||||
def __init__(self, yaml_path: Path):
|
||||
config = get_config_from_yaml(yaml_path)
|
||||
self.config = {}
|
||||
for k, v in config.items():
|
||||
self.config[k] = SingleTaskConfig(k, v)
|
||||
|
||||
def get_all_tasks(self):
|
||||
return self.config
|
||||
|
||||
def get_single_task(self, name):
|
||||
assert name in self.config, f"multitask '{name}' does not exist!"
|
||||
return self.config[name]
|
||||
|
||||
@property
|
||||
def first_pass_decoder_task_index(self):
|
||||
"""Return the task index of the first-pass text decoder.
|
||||
If there are multiple 'is_first_pass_decoder: True' in the config file,
|
||||
the last task is used for the first-pass decoder.
|
||||
If there is no 'is_first_pass_decoder: True' in the config file,
|
||||
the last task whose task_name includes 'target' and decoder_type is not ctc.
|
||||
"""
|
||||
idx = -1
|
||||
for i, (k, v) in enumerate(self.config.items()):
|
||||
if v.is_first_pass_decoder:
|
||||
idx = i
|
||||
if idx < 0:
|
||||
for i, (k, v) in enumerate(self.config.items()):
|
||||
if k.startswith("target") and v.decoder_type == "transformer":
|
||||
idx = i
|
||||
return idx
|
||||
|
||||
|
||||
class SingleTaskConfig(object):
|
||||
def __init__(self, name, config):
|
||||
self.task_name = name
|
||||
self.config = config
|
||||
dict_path = config.get("dict", "")
|
||||
self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.config.get("data", "")
|
||||
|
||||
@property
|
||||
def decoder_type(self):
|
||||
return self.config.get("decoder_type", "transformer")
|
||||
|
||||
@property
|
||||
def decoder_args(self):
|
||||
"""Decoder arch related args"""
|
||||
args = self.config.get("decoder_args", {})
|
||||
return Namespace(**args)
|
||||
|
||||
@property
|
||||
def criterion_cfg(self):
|
||||
"""cfg for the multitask criterion"""
|
||||
if self.decoder_type == "ctc":
|
||||
from fairseq.criterions.ctc import CtcCriterionConfig
|
||||
|
||||
cfg = CtcCriterionConfig
|
||||
cfg.zero_infinity = self.config.get("zero_infinity", True)
|
||||
else:
|
||||
from fairseq.criterions.label_smoothed_cross_entropy import (
|
||||
LabelSmoothedCrossEntropyCriterionConfig,
|
||||
)
|
||||
|
||||
cfg = LabelSmoothedCrossEntropyCriterionConfig
|
||||
cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
|
||||
return cfg
|
||||
|
||||
@property
|
||||
def input_from(self):
|
||||
"""Condition on encoder/decoder of the main model"""
|
||||
return "decoder" if "decoder_layer" in self.config else "encoder"
|
||||
|
||||
@property
|
||||
def input_layer(self):
|
||||
if self.input_from == "decoder":
|
||||
return self.config["decoder_layer"] - 1
|
||||
else:
|
||||
# default using the output from the last encoder layer (-1)
|
||||
return self.config.get("encoder_layer", 0) - 1
|
||||
|
||||
@property
|
||||
def loss_weight_schedule(self):
|
||||
return (
|
||||
"decay"
|
||||
if "loss_weight_max" in self.config
|
||||
and "loss_weight_decay_steps" in self.config
|
||||
else "fixed"
|
||||
)
|
||||
|
||||
def get_loss_weight(self, num_updates):
|
||||
if self.loss_weight_schedule == "fixed":
|
||||
weight = self.config.get("loss_weight", 1.0)
|
||||
else: # "decay"
|
||||
assert (
|
||||
self.config.get("loss_weight_decay_steps", 0) > 0
|
||||
), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
|
||||
loss_weight_min = self.config.get("loss_weight_min", 0.0001)
|
||||
loss_weight_decay_stepsize = (
|
||||
self.config["loss_weight_max"] - loss_weight_min
|
||||
) / self.config["loss_weight_decay_steps"]
|
||||
weight = max(
|
||||
self.config["loss_weight_max"]
|
||||
- loss_weight_decay_stepsize * num_updates,
|
||||
loss_weight_min,
|
||||
)
|
||||
return weight
|
||||
|
||||
@property
|
||||
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
|
||||
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
|
||||
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
"""EOS token during generation"""
|
||||
return self.config.get("eos_token", "<eos>")
|
||||
|
||||
@property
|
||||
def rdrop_alpha(self):
|
||||
return self.config.get("rdrop_alpha", 0.0)
|
||||
|
||||
@property
|
||||
def is_first_pass_decoder(self):
|
||||
flag = self.config.get("is_first_pass_decoder", False)
|
||||
if flag:
|
||||
if self.decoder_type == "ctc":
|
||||
raise ValueError(
|
||||
"First-pass decoder in the multi-decoder model must not be CTC."
|
||||
)
|
||||
if "target" not in self.task_name:
|
||||
raise Warning(
|
||||
'The name of the first-pass decoder does not include "target".'
|
||||
)
|
||||
return flag
|
||||
|
||||
@property
|
||||
def get_lang_tag_mapping(self):
|
||||
return self.config.get("lang_tag_mapping", {})
|
||||
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioDatasetTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def get_audio_dataset_transform(name):
|
||||
return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
def register_audio_dataset_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioDatasetTransform,
|
||||
AUDIO_DATASET_TRANSFORM_REGISTRY,
|
||||
AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
import_transforms(os.path.dirname(__file__), "dataset")
|
||||
|
||||
|
||||
class CompositeAudioDatasetTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"dataset",
|
||||
get_audio_dataset_transform,
|
||||
CompositeAudioDatasetTransform,
|
||||
config,
|
||||
return_empty=True,
|
||||
)
|
||||
|
||||
def get_transform(self, cls):
|
||||
for t in self.transforms:
|
||||
if isinstance(t, cls):
|
||||
return t
|
||||
return None
|
||||
|
||||
def has_transform(self, cls):
|
||||
return self.get_transform(cls) is not None
|
||||
@@ -0,0 +1,61 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
from fairseq.data.audio.dataset_transforms import (
|
||||
AudioDatasetTransform,
|
||||
register_audio_dataset_transform,
|
||||
)
|
||||
|
||||
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
|
||||
|
||||
|
||||
@register_audio_dataset_transform("concataugment")
|
||||
class ConcatAugment(AudioDatasetTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return ConcatAugment(
|
||||
_config.get("rate", _DEFAULTS["rate"]),
|
||||
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
|
||||
_config.get("attempts", _DEFAULTS["attempts"]),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate=_DEFAULTS["rate"],
|
||||
max_tokens=_DEFAULTS["max_tokens"],
|
||||
attempts=_DEFAULTS["attempts"],
|
||||
):
|
||||
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"rate={self.rate}",
|
||||
f"max_tokens={self.max_tokens}",
|
||||
f"attempts={self.attempts}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
|
||||
# skip conditions: application rate, max_tokens limit exceeded
|
||||
if np.random.random() > self.rate:
|
||||
return [index]
|
||||
if self.max_tokens and n_frames[index] > self.max_tokens:
|
||||
return [index]
|
||||
|
||||
# pick second sample to concatenate
|
||||
for _ in range(self.attempts):
|
||||
index2 = np.random.randint(0, n_samples)
|
||||
if index2 != index and (
|
||||
not self.max_tokens
|
||||
or n_frames[index] + n_frames[index2] < self.max_tokens
|
||||
):
|
||||
return [index, index2]
|
||||
|
||||
return [index]
|
||||
@@ -0,0 +1,105 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fairseq.data.audio import rand_uniform
|
||||
from fairseq.data.audio.dataset_transforms import (
|
||||
AudioDatasetTransform,
|
||||
register_audio_dataset_transform,
|
||||
)
|
||||
from fairseq.data.audio.waveform_transforms.noiseaugment import (
|
||||
NoiseAugmentTransform,
|
||||
)
|
||||
|
||||
_DEFAULTS = {
|
||||
"rate": 0.25,
|
||||
"mixing_noise_rate": 0.1,
|
||||
"noise_path": "",
|
||||
"noise_snr_min": -5,
|
||||
"noise_snr_max": 5,
|
||||
"utterance_snr_min": -5,
|
||||
"utterance_snr_max": 5,
|
||||
}
|
||||
|
||||
|
||||
@register_audio_dataset_transform("noisyoverlapaugment")
|
||||
class NoisyOverlapAugment(AudioDatasetTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return NoisyOverlapAugment(
|
||||
_config.get("rate", _DEFAULTS["rate"]),
|
||||
_config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
|
||||
_config.get("noise_path", _DEFAULTS["noise_path"]),
|
||||
_config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
|
||||
_config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
|
||||
_config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
|
||||
_config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate=_DEFAULTS["rate"],
|
||||
mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
|
||||
noise_path=_DEFAULTS["noise_path"],
|
||||
noise_snr_min=_DEFAULTS["noise_snr_min"],
|
||||
noise_snr_max=_DEFAULTS["noise_snr_max"],
|
||||
utterance_snr_min=_DEFAULTS["utterance_snr_min"],
|
||||
utterance_snr_max=_DEFAULTS["utterance_snr_max"],
|
||||
):
|
||||
self.rate = rate
|
||||
self.mixing_noise_rate = mixing_noise_rate
|
||||
self.noise_shaper = NoiseAugmentTransform(noise_path)
|
||||
self.noise_snr_min = noise_snr_min
|
||||
self.noise_snr_max = noise_snr_max
|
||||
self.utterance_snr_min = utterance_snr_min
|
||||
self.utterance_snr_max = utterance_snr_max
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"rate={self.rate}",
|
||||
f"mixing_noise_rate={self.mixing_noise_rate}",
|
||||
f"noise_snr_min={self.noise_snr_min}",
|
||||
f"noise_snr_max={self.noise_snr_max}",
|
||||
f"utterance_snr_min={self.utterance_snr_min}",
|
||||
f"utterance_snr_max={self.utterance_snr_max}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def __call__(self, sources):
|
||||
for i, source in enumerate(sources):
|
||||
if np.random.random() > self.rate:
|
||||
continue
|
||||
|
||||
pri = source.numpy()
|
||||
|
||||
if np.random.random() > self.mixing_noise_rate:
|
||||
sec = sources[np.random.randint(0, len(sources))].numpy()
|
||||
snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
|
||||
else:
|
||||
sec = self.noise_shaper.pick_sample(source.shape)
|
||||
snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
|
||||
|
||||
L1 = pri.shape[-1]
|
||||
L2 = sec.shape[-1]
|
||||
l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
|
||||
s_source = np.random.randint(0, L1 - l)
|
||||
s_sec = np.random.randint(0, L2 - l)
|
||||
|
||||
get_power = lambda x: np.mean(x**2)
|
||||
if get_power(sec) == 0:
|
||||
continue
|
||||
|
||||
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
|
||||
|
||||
pri[s_source : s_source + l] = np.add(
|
||||
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
|
||||
)
|
||||
sources[i] = torch.from_numpy(pri).float()
|
||||
|
||||
return sources
|
||||
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioFeatureTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def get_audio_feature_transform(name):
|
||||
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
def register_audio_feature_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioFeatureTransform,
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY,
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
import_transforms(os.path.dirname(__file__), "feature")
|
||||
|
||||
|
||||
class CompositeAudioFeatureTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"feature",
|
||||
get_audio_feature_transform,
|
||||
CompositeAudioFeatureTransform,
|
||||
config,
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("delta_deltas")
|
||||
class DeltaDeltas(AudioFeatureTransform):
|
||||
"""Expand delta-deltas features from spectrum."""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return DeltaDeltas(_config.get("win_length", 5))
|
||||
|
||||
def __init__(self, win_length=5):
|
||||
self.win_length = win_length
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __call__(self, spectrogram):
|
||||
from torchaudio.functional import compute_deltas
|
||||
|
||||
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
||||
# spectrogram is T x F, while compute_deltas takes (…, F, T)
|
||||
spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
|
||||
delta = compute_deltas(spectrogram)
|
||||
delta_delta = compute_deltas(delta)
|
||||
|
||||
out_feat = np.concatenate(
|
||||
[spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
|
||||
)
|
||||
out_feat = np.transpose(out_feat)
|
||||
return out_feat
|
||||
@@ -0,0 +1,29 @@
|
||||
import numpy as np
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("global_cmvn")
|
||||
class GlobalCMVN(AudioFeatureTransform):
|
||||
"""Global CMVN (cepstral mean and variance normalization). The global mean
|
||||
and variance need to be pre-computed and stored in NumPy format (.npz)."""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return GlobalCMVN(_config.get("stats_npz_path"))
|
||||
|
||||
def __init__(self, stats_npz_path):
|
||||
self.stats_npz_path = stats_npz_path
|
||||
stats = np.load(stats_npz_path)
|
||||
self.mean, self.std = stats["mean"], stats["std"]
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'
|
||||
|
||||
def __call__(self, x):
|
||||
x = np.subtract(x, self.mean)
|
||||
x = np.divide(x, self.std)
|
||||
return x
|
||||
@@ -0,0 +1,131 @@
|
||||
import math
|
||||
import numbers
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("specaugment")
|
||||
class SpecAugmentTransform(AudioFeatureTransform):
|
||||
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return SpecAugmentTransform(
|
||||
_config.get("time_warp_W", 0),
|
||||
_config.get("freq_mask_N", 0),
|
||||
_config.get("freq_mask_F", 0),
|
||||
_config.get("time_mask_N", 0),
|
||||
_config.get("time_mask_T", 0),
|
||||
_config.get("time_mask_p", 0.0),
|
||||
_config.get("mask_value", None),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_warp_w: int = 0,
|
||||
freq_mask_n: int = 0,
|
||||
freq_mask_f: int = 0,
|
||||
time_mask_n: int = 0,
|
||||
time_mask_t: int = 0,
|
||||
time_mask_p: float = 0.0,
|
||||
mask_value: Optional[float] = 0.0,
|
||||
):
|
||||
# Sanity checks
|
||||
assert mask_value is None or isinstance(
|
||||
mask_value, numbers.Number
|
||||
), f"mask_value (type: {type(mask_value)}) must be None or a number"
|
||||
if freq_mask_n > 0:
|
||||
assert freq_mask_f > 0, (
|
||||
f"freq_mask_F ({freq_mask_f}) "
|
||||
f"must be larger than 0 when doing freq masking."
|
||||
)
|
||||
if time_mask_n > 0:
|
||||
assert time_mask_t > 0, (
|
||||
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
|
||||
f"doing time masking."
|
||||
)
|
||||
|
||||
self.time_warp_w = time_warp_w
|
||||
self.freq_mask_n = freq_mask_n
|
||||
self.freq_mask_f = freq_mask_f
|
||||
self.time_mask_n = time_mask_n
|
||||
self.time_mask_t = time_mask_t
|
||||
self.time_mask_p = time_mask_p
|
||||
self.mask_value = mask_value
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"time_warp_w={self.time_warp_w}",
|
||||
f"freq_mask_n={self.freq_mask_n}",
|
||||
f"freq_mask_f={self.freq_mask_f}",
|
||||
f"time_mask_n={self.time_mask_n}",
|
||||
f"time_mask_t={self.time_mask_t}",
|
||||
f"time_mask_p={self.time_mask_p}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def __call__(self, spectrogram):
|
||||
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
||||
|
||||
distorted = spectrogram.copy() # make a copy of input spectrogram.
|
||||
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
|
||||
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
|
||||
mask_value = self.mask_value
|
||||
|
||||
if mask_value is None: # if no value was specified, use local mean.
|
||||
mask_value = spectrogram.mean()
|
||||
|
||||
if num_frames == 0:
|
||||
return spectrogram
|
||||
|
||||
if num_freqs < self.freq_mask_f:
|
||||
return spectrogram
|
||||
|
||||
if self.time_warp_w > 0:
|
||||
if 2 * self.time_warp_w < num_frames:
|
||||
import cv2
|
||||
|
||||
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
|
||||
w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
|
||||
upper, lower = distorted[:w0, :], distorted[w0:, :]
|
||||
upper = cv2.resize(
|
||||
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
|
||||
)
|
||||
lower = cv2.resize(
|
||||
lower,
|
||||
dsize=(num_freqs, num_frames - w0 - w),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
distorted = np.concatenate((upper, lower), axis=0)
|
||||
|
||||
for _i in range(self.freq_mask_n):
|
||||
f = np.random.randint(0, self.freq_mask_f)
|
||||
f0 = np.random.randint(0, num_freqs - f)
|
||||
if f != 0:
|
||||
distorted[:, f0 : f0 + f] = mask_value
|
||||
|
||||
max_time_mask_t = min(
|
||||
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
|
||||
)
|
||||
if max_time_mask_t < 1:
|
||||
return distorted
|
||||
|
||||
for _i in range(self.time_mask_n):
|
||||
t = np.random.randint(0, max_time_mask_t)
|
||||
t0 = np.random.randint(0, num_frames - t)
|
||||
if t != 0:
|
||||
distorted[t0 : t0 + t, :] = mask_value
|
||||
|
||||
return distorted
|
||||
@@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("utterance_cmvn")
|
||||
class UtteranceCMVN(AudioFeatureTransform):
|
||||
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return UtteranceCMVN(
|
||||
_config.get("norm_means", True),
|
||||
_config.get("norm_vars", True),
|
||||
)
|
||||
|
||||
def __init__(self, norm_means=True, norm_vars=True):
|
||||
self.norm_means, self.norm_vars = norm_means, norm_vars
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
mean = x.mean(axis=0)
|
||||
square_sums = (x**2).sum(axis=0)
|
||||
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, mean)
|
||||
if self.norm_vars:
|
||||
var = square_sums / x.shape[0] - mean**2
|
||||
std = np.sqrt(np.maximum(var, 1e-10))
|
||||
x = np.divide(x, std)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.abs
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import os.path as op
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
|
||||
from fairseq.data.audio.text_to_speech_dataset import (
|
||||
TextToSpeechDataset,
|
||||
TextToSpeechDatasetCreator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrmTextToSpeechDataset(TextToSpeechDataset):
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
data_cfg: S2TDataConfig,
|
||||
audio_paths: List[str],
|
||||
n_frames: List[int],
|
||||
src_texts: Optional[List[str]] = None,
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
speakers: Optional[List[str]] = None,
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
tgt_dict: Optional[Dictionary] = None,
|
||||
pre_tokenizer=None,
|
||||
bpe_tokenizer=None,
|
||||
n_frames_per_step=1,
|
||||
speaker_to_id=None,
|
||||
do_chunk=False,
|
||||
chunk_bound=-1,
|
||||
chunk_init=50,
|
||||
chunk_incr=5,
|
||||
add_eos=True,
|
||||
dedup=True,
|
||||
ref_fpu=-1,
|
||||
):
|
||||
# It assumes texts are encoded at a fixed frame-rate
|
||||
super().__init__(
|
||||
split=split,
|
||||
is_train_split=is_train_split,
|
||||
data_cfg=data_cfg,
|
||||
audio_paths=audio_paths,
|
||||
n_frames=n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
speaker_to_id=speaker_to_id,
|
||||
)
|
||||
|
||||
self.do_chunk = do_chunk
|
||||
self.chunk_bound = chunk_bound
|
||||
self.chunk_init = chunk_init
|
||||
self.chunk_incr = chunk_incr
|
||||
self.add_eos = add_eos
|
||||
self.dedup = dedup
|
||||
self.ref_fpu = ref_fpu
|
||||
|
||||
self.chunk_size = -1
|
||||
|
||||
if do_chunk:
|
||||
assert self.chunk_incr >= 0
|
||||
assert self.pre_tokenizer is None
|
||||
|
||||
def __getitem__(self, index):
|
||||
index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
|
||||
if target[-1].item() == self.tgt_dict.eos_index:
|
||||
target = target[:-1]
|
||||
|
||||
fpu = source.size(0) / target.size(0) # frame-per-unit
|
||||
fps = self.n_frames_per_step
|
||||
assert (
|
||||
self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
|
||||
), f"{fpu*fps} != {self.ref_fpu}"
|
||||
|
||||
# only chunk training split
|
||||
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
|
||||
lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
|
||||
text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
|
||||
size = len(text)
|
||||
chunk_size = min(self.chunk_size, size)
|
||||
chunk_start = np.random.randint(size - chunk_size + 1)
|
||||
text = text[chunk_start : chunk_start + chunk_size]
|
||||
target = torch.cat((lang, text), 0)
|
||||
|
||||
f_size = int(np.floor(chunk_size * fpu))
|
||||
f_start = int(np.floor(chunk_start * fpu))
|
||||
assert f_size > 0
|
||||
source = source[f_start : f_start + f_size, :]
|
||||
|
||||
if self.dedup:
|
||||
target = torch.unique_consecutive(target)
|
||||
|
||||
if self.add_eos:
|
||||
eos_idx = self.tgt_dict.eos_index
|
||||
target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
|
||||
|
||||
return index, source, target, speaker_id
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
if self.is_train_split and self.do_chunk:
|
||||
old = self.chunk_size
|
||||
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
|
||||
if self.chunk_bound > 0:
|
||||
self.chunk_size = min(self.chunk_size, self.chunk_bound)
|
||||
logger.info(
|
||||
(
|
||||
f"{self.split}: setting chunk size "
|
||||
f"from {old} to {self.chunk_size}"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
|
||||
# inherit for key names
|
||||
@classmethod
|
||||
def from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
data_cfg: S2TDataConfig,
|
||||
split: str,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
is_train_split: bool,
|
||||
n_frames_per_step: int,
|
||||
speaker_to_id,
|
||||
do_chunk: bool = False,
|
||||
chunk_bound: int = -1,
|
||||
chunk_init: int = 50,
|
||||
chunk_incr: int = 5,
|
||||
add_eos: bool = True,
|
||||
dedup: bool = True,
|
||||
ref_fpu: float = -1,
|
||||
) -> FrmTextToSpeechDataset:
|
||||
tsv_path = op.join(root, f"{split}.tsv")
|
||||
if not op.isfile(tsv_path):
|
||||
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
||||
with open(tsv_path) as f:
|
||||
reader = csv.DictReader(
|
||||
f,
|
||||
delimiter="\t",
|
||||
quotechar=None,
|
||||
doublequote=False,
|
||||
lineterminator="\n",
|
||||
quoting=csv.QUOTE_NONE,
|
||||
)
|
||||
s = [dict(e) for e in reader]
|
||||
assert len(s) > 0
|
||||
|
||||
ids = [ss[cls.KEY_ID] for ss in s]
|
||||
audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
|
||||
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
|
||||
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
|
||||
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
|
||||
speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
|
||||
src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
|
||||
tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
|
||||
|
||||
return FrmTextToSpeechDataset(
|
||||
split=split,
|
||||
is_train_split=is_train_split,
|
||||
data_cfg=data_cfg,
|
||||
audio_paths=audio_paths,
|
||||
n_frames=n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
speaker_to_id=speaker_to_id,
|
||||
do_chunk=do_chunk,
|
||||
chunk_bound=chunk_bound,
|
||||
chunk_init=chunk_init,
|
||||
chunk_incr=chunk_incr,
|
||||
add_eos=add_eos,
|
||||
dedup=dedup,
|
||||
ref_fpu=ref_fpu,
|
||||
)
|
||||
356
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py
Normal file
356
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.data.fairseq_dataset import FairseqDataset
|
||||
from fairseq.data.audio.audio_utils import (
|
||||
parse_path,
|
||||
read_from_stored_zip,
|
||||
)
|
||||
import io
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_audio(manifest_path, max_keep, min_keep):
|
||||
n_long, n_short = 0, 0
|
||||
names, inds, sizes = [], [], []
|
||||
with open(manifest_path) as f:
|
||||
root = f.readline().strip()
|
||||
for ind, line in enumerate(f):
|
||||
items = line.strip().split("\t")
|
||||
assert len(items) == 2, line
|
||||
sz = int(items[1])
|
||||
if min_keep is not None and sz < min_keep:
|
||||
n_short += 1
|
||||
elif max_keep is not None and sz > max_keep:
|
||||
n_long += 1
|
||||
else:
|
||||
names.append(items[0])
|
||||
inds.append(ind)
|
||||
sizes.append(sz)
|
||||
tot = ind + 1
|
||||
logger.info(
|
||||
(
|
||||
f"max_keep={max_keep}, min_keep={min_keep}, "
|
||||
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
||||
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
||||
)
|
||||
)
|
||||
return root, names, inds, tot, sizes
|
||||
|
||||
|
||||
def load_label(label_path, inds, tot):
|
||||
with open(label_path) as f:
|
||||
labels = [line.rstrip() for line in f]
|
||||
assert (
|
||||
len(labels) == tot
|
||||
), f"number of labels does not match ({len(labels)} != {tot})"
|
||||
labels = [labels[i] for i in inds]
|
||||
return labels
|
||||
|
||||
|
||||
def load_label_offset(label_path, inds, tot):
|
||||
with open(label_path) as f:
|
||||
code_lengths = [len(line.encode("utf-8")) for line in f]
|
||||
assert (
|
||||
len(code_lengths) == tot
|
||||
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
||||
offsets = list(itertools.accumulate([0] + code_lengths))
|
||||
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
||||
return offsets
|
||||
|
||||
|
||||
def verify_label_lengths(
|
||||
audio_sizes,
|
||||
audio_rate,
|
||||
label_path,
|
||||
label_rate,
|
||||
inds,
|
||||
tot,
|
||||
tol=0.1, # tolerance in seconds
|
||||
):
|
||||
if label_rate < 0:
|
||||
logger.info(f"{label_path} is sequence label. skipped")
|
||||
return
|
||||
|
||||
with open(label_path) as f:
|
||||
lengths = [len(line.rstrip().split()) for line in f]
|
||||
assert len(lengths) == tot
|
||||
lengths = [lengths[i] for i in inds]
|
||||
num_invalid = 0
|
||||
for i, ind in enumerate(inds):
|
||||
dur_from_audio = audio_sizes[i] / audio_rate
|
||||
dur_from_label = lengths[i] / label_rate
|
||||
if abs(dur_from_audio - dur_from_label) > tol:
|
||||
logger.warning(
|
||||
(
|
||||
f"audio and label duration differ too much "
|
||||
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
|
||||
f"in line {ind+1} of {label_path}. Check if `label_rate` "
|
||||
f"is correctly set (currently {label_rate}). "
|
||||
f"num. of samples = {audio_sizes[i]}; "
|
||||
f"label length = {lengths[i]}"
|
||||
)
|
||||
)
|
||||
num_invalid += 1
|
||||
if num_invalid > 0:
|
||||
logger.warning(
|
||||
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
|
||||
)
|
||||
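# Sanity-check example for verify_label_lengths (illustrative numbers): with
# audio_rate=16000 and label_rate=50, an utterance of 160000 samples lasts 10.0 s
# and should come with roughly 500 label frames; a label sequence of only 480 frames
# (9.6 s) differs by 0.4 s > tol=0.1 s and would be reported as mismatched.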
|
||||
|
||||
class HubertDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
manifest_path: str,
|
||||
sample_rate: float,
|
||||
label_paths: List[str],
|
||||
label_rates: Union[List[float], float], # -1 for sequence labels
|
||||
pad_list: List[str],
|
||||
eos_list: List[str],
|
||||
label_processors: Optional[List[Any]] = None,
|
||||
max_keep_sample_size: Optional[int] = None,
|
||||
min_keep_sample_size: Optional[int] = None,
|
||||
max_sample_size: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
pad_audio: bool = False,
|
||||
normalize: bool = False,
|
||||
store_labels: bool = True,
|
||||
random_crop: bool = False,
|
||||
single_target: bool = False,
|
||||
):
|
||||
self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
|
||||
manifest_path, max_keep_sample_size, min_keep_sample_size
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.shuffle = shuffle
|
||||
self.random_crop = random_crop
|
||||
|
||||
self.num_labels = len(label_paths)
|
||||
self.pad_list = pad_list
|
||||
self.eos_list = eos_list
|
||||
self.label_processors = label_processors
|
||||
self.single_target = single_target
|
||||
self.label_rates = (
|
||||
[label_rates for _ in range(len(label_paths))]
|
||||
if isinstance(label_rates, float)
|
||||
else label_rates
|
||||
)
|
||||
self.store_labels = store_labels
|
||||
if store_labels:
|
||||
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
||||
else:
|
||||
self.label_paths = label_paths
|
||||
self.label_offsets_list = [
|
||||
load_label_offset(p, inds, tot) for p in label_paths
|
||||
]
|
||||
assert label_processors is None or len(label_processors) == self.num_labels
|
||||
for label_path, label_rate in zip(label_paths, self.label_rates):
|
||||
verify_label_lengths(
|
||||
self.sizes, sample_rate, label_path, label_rate, inds, tot
|
||||
)
|
||||
|
||||
self.max_sample_size = (
|
||||
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||
)
|
||||
self.pad_audio = pad_audio
|
||||
self.normalize = normalize
|
||||
logger.info(
|
||||
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
||||
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
||||
)
|
||||
|
||||
def get_audio(self, index):
|
||||
import soundfile as sf
|
||||
|
||||
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
||||
_path, slice_ptr = parse_path(wav_path)
|
||||
if len(slice_ptr) == 0:
|
||||
wav, cur_sample_rate = sf.read(_path)
|
||||
else:
|
||||
assert _path.endswith(".zip")
|
||||
data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
||||
f = io.BytesIO(data)
|
||||
wav, cur_sample_rate = sf.read(f)
|
||||
wav = torch.from_numpy(wav).float()
|
||||
wav = self.postprocess(wav, cur_sample_rate)
|
||||
return wav
|
||||
|
||||
def get_label(self, index, label_idx):
|
||||
if self.store_labels:
|
||||
label = self.label_list[label_idx][index]
|
||||
else:
|
||||
with open(self.label_paths[label_idx]) as f:
|
||||
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
||||
f.seek(offset_s)
|
||||
label = f.read(offset_e - offset_s)
|
||||
|
||||
if self.label_processors is not None:
|
||||
label = self.label_processors[label_idx](label)
|
||||
return label
|
||||
|
||||
def get_labels(self, index):
|
||||
return [self.get_label(index, i) for i in range(self.num_labels)]
|
||||
|
||||
def __getitem__(self, index):
|
||||
wav = self.get_audio(index)
|
||||
labels = self.get_labels(index)
|
||||
return {"id": index, "source": wav, "label_list": labels}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sizes)
|
||||
|
||||
def crop_to_max_size(self, wav, target_size):
|
||||
size = len(wav)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return wav, 0
|
||||
|
||||
start, end = 0, target_size
|
||||
if self.random_crop:
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return wav[start:end], start
|
||||
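# crop_to_max_size example (illustrative): for a 120000-sample waveform and
# target_size=80000, diff=40000, so with random_crop=True a start offset is drawn
# from [0, 40000] and an 80000-sample window plus that offset are returned;
# with random_crop=False the first 80000 samples are kept and start stays 0.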
|
||||
def collater(self, samples):
|
||||
# target = max(sizes) -> random_crop not used
|
||||
# target = max_sample_size -> random_crop used for long
|
||||
samples = [s for s in samples if s["source"] is not None]
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
audios = [s["source"] for s in samples]
|
||||
audio_sizes = [len(s) for s in audios]
|
||||
if self.pad_audio:
|
||||
audio_size = min(max(audio_sizes), self.max_sample_size)
|
||||
else:
|
||||
audio_size = min(min(audio_sizes), self.max_sample_size)
|
||||
collated_audios, padding_mask, audio_starts = self.collater_audio(
|
||||
audios, audio_size
|
||||
)
|
||||
|
||||
targets_by_label = [
|
||||
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
|
||||
]
|
||||
targets_list, lengths_list, ntokens_list = self.collater_label(
|
||||
targets_by_label, audio_size, audio_starts
|
||||
)
|
||||
|
||||
net_input = {"source": collated_audios, "padding_mask": padding_mask}
|
||||
batch = {
|
||||
"id": torch.LongTensor([s["id"] for s in samples]),
|
||||
"net_input": net_input,
|
||||
}
|
||||
|
||||
if self.single_target:
|
||||
batch["target_lengths"] = lengths_list[0]
|
||||
batch["ntokens"] = ntokens_list[0]
|
||||
batch["target"] = targets_list[0]
|
||||
else:
|
||||
batch["target_lengths_list"] = lengths_list
|
||||
batch["ntokens_list"] = ntokens_list
|
||||
batch["target_list"] = targets_list
|
||||
return batch
|
||||
|
||||
def collater_audio(self, audios, audio_size):
|
||||
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
||||
padding_mask = (
|
||||
torch.BoolTensor(collated_audios.shape).fill_(False)
|
||||
# if self.pad_audio else None
|
||||
)
|
||||
audio_starts = [0 for _ in audios]
|
||||
for i, audio in enumerate(audios):
|
||||
diff = len(audio) - audio_size
|
||||
if diff == 0:
|
||||
collated_audios[i] = audio
|
||||
elif diff < 0:
|
||||
assert self.pad_audio
|
||||
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
||||
padding_mask[i, diff:] = True
|
||||
else:
|
||||
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
||||
audio, audio_size
|
||||
)
|
||||
return collated_audios, padding_mask, audio_starts
|
||||
|
||||
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
||||
assert label_rate > 0
|
||||
s2f = label_rate / self.sample_rate
|
||||
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
||||
frm_size = int(round(audio_size * s2f))
|
||||
if not self.pad_audio:
|
||||
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
||||
frm_size = min(frm_size, *rem_size)
|
||||
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
|
||||
logger.debug(f"audio_starts={audio_starts}")
|
||||
logger.debug(f"frame_starts={frm_starts}")
|
||||
logger.debug(f"frame_size={frm_size}")
|
||||
|
||||
lengths = torch.LongTensor([len(t) for t in targets])
|
||||
ntokens = lengths.sum().item()
|
||||
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
||||
return targets, lengths, ntokens
|
||||
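# collater_frm_label maps sample offsets to label-frame offsets via
# s2f = label_rate / sample_rate. Illustrative numbers: label_rate=50 and
# sample_rate=16000 give s2f = 0.003125, so an audio crop starting at sample 32000
# starts at label frame round(32000 * 0.003125) = 100, and an 80000-sample crop
# spans 250 label frames.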
|
||||
def collater_seq_label(self, targets, pad):
|
||||
lengths = torch.LongTensor([len(t) for t in targets])
|
||||
ntokens = lengths.sum().item()
|
||||
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
||||
return targets, lengths, ntokens
|
||||
|
||||
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
||||
targets_list, lengths_list, ntokens_list = [], [], []
|
||||
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
||||
for targets, label_rate, pad in itr:
|
||||
if label_rate == -1.0:
|
||||
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
||||
else:
|
||||
targets, lengths, ntokens = self.collater_frm_label(
|
||||
targets, audio_size, audio_starts, label_rate, pad
|
||||
)
|
||||
targets_list.append(targets)
|
||||
lengths_list.append(lengths)
|
||||
ntokens_list.append(ntokens)
|
||||
return targets_list, lengths_list, ntokens_list
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.size(index)
|
||||
|
||||
def size(self, index):
|
||||
if self.pad_audio:
|
||||
return self.sizes[index]
|
||||
return min(self.sizes[index], self.max_sample_size)
|
||||
|
||||
def ordered_indices(self):
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
|
||||
order.append(self.sizes)
|
||||
return np.lexsort(order)[::-1]
|
||||
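# ordered_indices sorts primarily by the last lexsort key (the sizes), so indices are
# returned from longest to shortest sample; when shuffle=True the random permutation
# only breaks ties between equally sized samples.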
|
||||
def postprocess(self, wav, cur_sample_rate):
|
||||
if wav.dim() == 2:
|
||||
wav = wav.mean(-1)
|
||||
assert wav.dim() == 1, wav.dim()
|
||||
|
||||
if cur_sample_rate != self.sample_rate:
|
||||
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
||||
|
||||
if self.normalize:
|
||||
with torch.no_grad():
|
||||
wav = F.layer_norm(wav, wav.shape)
|
||||
return wav
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright (c) 2021-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Optional, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data.resampling_dataset import ResamplingDataset
|
||||
import torch
|
||||
from fairseq.data import (
|
||||
ConcatDataset,
|
||||
LanguagePairDataset,
|
||||
FileAudioDataset,
|
||||
data_utils,
|
||||
)
|
||||
from fairseq.data import FairseqDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModalityDatasetItem(NamedTuple):
|
||||
datasetname: str
|
||||
dataset: any
|
||||
max_positions: List[int]
|
||||
max_tokens: Optional[int] = None
|
||||
max_sentences: Optional[int] = None
|
||||
|
||||
|
||||
def resampling_dataset_present(ds):
|
||||
if isinstance(ds, ResamplingDataset):
|
||||
return True
|
||||
if isinstance(ds, ConcatDataset):
|
||||
return any(resampling_dataset_present(d) for d in ds.datasets)
|
||||
if hasattr(ds, "dataset"):
|
||||
return resampling_dataset_present(ds.dataset)
|
||||
return False
|
||||
|
||||
|
||||
# MultiModalityDataset: concatenates multiple datasets with different modalities.
# Compared with ConcatDataset, it can 1) sample data according to the given ratios for
# the different datasets and 2) add a "mode" field indicating which type of dataset each
# sample comes from. It is used together with GroupedEpochBatchIterator to build
# mini-batches whose samples all come from the same type of dataset. If only one dataset
# is used, it behaves like the original dataset with the mode added.
|
||||
class MultiModalityDataset(ConcatDataset):
|
||||
def __init__(self, datasets: List[ModalityDatasetItem]):
|
||||
id_to_mode = []
|
||||
dsets = []
|
||||
max_tokens = []
|
||||
max_sentences = []
|
||||
max_positions = []
|
||||
for dset in datasets:
|
||||
id_to_mode.append(dset.datasetname)
|
||||
dsets.append(dset.dataset)
|
||||
max_tokens.append(dset.max_tokens)
|
||||
max_positions.append(dset.max_positions)
|
||||
max_sentences.append(dset.max_sentences)
|
||||
weights = [1.0 for s in dsets]
|
||||
super().__init__(dsets, weights)
|
||||
self.max_tokens = max_tokens
|
||||
self.max_positions = max_positions
|
||||
self.max_sentences = max_sentences
|
||||
self.id_to_mode = id_to_mode
|
||||
self.raw_sub_batch_samplers = []
|
||||
self._cur_epoch = 0
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
super().set_epoch(epoch)
|
||||
self._cur_epoch = epoch
|
||||
|
||||
def __getitem__(self, idx):
|
||||
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
||||
sample = self.datasets[dataset_idx][sample_idx]
|
||||
return (dataset_idx, sample)
|
||||
|
||||
def collater(self, samples):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
dataset_idx = samples[0][0]
|
||||
# make sure all samples in samples are from same dataset
|
||||
assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
|
||||
samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
|
||||
# add mode
|
||||
samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
|
||||
|
||||
return samples
|
||||
|
||||
def size(self, index: int):
|
||||
if len(self.datasets) == 1:
|
||||
return self.datasets[0].size(index)
|
||||
return super().size(index)
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
if len(self.datasets) == 1:
|
||||
return self.datasets[0].sizes
|
||||
return super().sizes
|
||||
|
||||
def ordered_indices(self):
|
||||
"""
|
||||
Returns indices sorted by length. So less padding is needed.
|
||||
"""
|
||||
if len(self.datasets) == 1:
|
||||
return self.datasets[0].ordered_indices()
|
||||
indices_group = []
|
||||
for d_idx, ds in enumerate(self.datasets):
|
||||
sample_num = self.cumulative_sizes[d_idx]
|
||||
if d_idx > 0:
|
||||
sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
|
||||
assert sample_num == len(ds)
|
||||
indices_group.append(ds.ordered_indices())
|
||||
return indices_group
|
||||
|
||||
def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
|
||||
with data_utils.numpy_seed(seed):
|
||||
indices = self.ordered_indices()
|
||||
for i, ds in enumerate(self.datasets):
|
||||
# If we have ResamplingDataset, the same id can correspond to a different
|
||||
# sample in the next epoch, so we need to rebuild this at every epoch
|
||||
if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
|
||||
ds
|
||||
):
|
||||
logger.info(f"dataset {i} is valid and it is not re-sampled")
|
||||
continue
|
||||
indices[i] = ds.filter_indices_by_size(
|
||||
indices[i],
|
||||
self.max_positions[i],
|
||||
)[0]
|
||||
sub_batch_sampler = ds.batch_by_size(
|
||||
indices[i],
|
||||
max_tokens=self.max_tokens[i],
|
||||
max_sentences=self.max_sentences[i],
|
||||
required_batch_size_multiple=required_batch_size_multiple,
|
||||
)
|
||||
if i < len(self.raw_sub_batch_samplers):
|
||||
self.raw_sub_batch_samplers[i] = sub_batch_sampler
|
||||
else:
|
||||
self.raw_sub_batch_samplers.append(sub_batch_sampler)
|
||||
|
||||
def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
|
||||
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
|
||||
batch_samplers = []
|
||||
for i, _ in enumerate(self.datasets):
|
||||
if i > 0:
|
||||
sub_batch_sampler = [
|
||||
[y + self.cumulative_sizes[i - 1] for y in x]
|
||||
for x in self.raw_sub_batch_samplers[i]
|
||||
]
|
||||
else:
|
||||
sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
|
||||
smp_r = mult_ratios[i]
|
||||
if smp_r != 1:
|
||||
is_increase = "increased" if smp_r > 1 else "decreased"
|
||||
logger.info(
|
||||
"number of batch for the dataset {} is {} from {} to {}".format(
|
||||
self.id_to_mode[i],
|
||||
is_increase,
|
||||
len(sub_batch_sampler),
|
||||
int(len(sub_batch_sampler) * smp_r),
|
||||
)
|
||||
)
|
||||
mul_samplers = []
|
||||
for _ in range(math.floor(smp_r)):
|
||||
mul_samplers = mul_samplers + sub_batch_sampler
|
||||
if math.floor(smp_r) != smp_r:
|
||||
with data_utils.numpy_seed(seed + self._cur_epoch):
|
||||
np.random.shuffle(sub_batch_sampler)
|
||||
smp_num = int(
|
||||
(smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
|
||||
)
|
||||
mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
|
||||
sub_batch_sampler = mul_samplers
|
||||
else:
|
||||
logger.info(
|
||||
"dataset {} batch number is {} ".format(
|
||||
self.id_to_mode[i], len(sub_batch_sampler)
|
||||
)
|
||||
)
|
||||
batch_samplers.append(sub_batch_sampler)
|
||||
|
||||
return batch_samplers
|
||||
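# get_batch_samplers example (illustrative ratio): mult_ratios=[1.5] keeps one full
# copy of the dataset's batches (floor(1.5) copies) and then appends a randomly chosen
# half of them, so 100 original batches become roughly 150 after resampling.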
|
||||
|
||||
class LangPairMaskDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: LanguagePairDataset,
|
||||
src_eos: int,
|
||||
src_bos: Optional[int] = None,
|
||||
noise_id: Optional[int] = -1,
|
||||
mask_ratio: Optional[float] = 0,
|
||||
mask_type: Optional[str] = "random",
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.src_eos = src_eos
|
||||
self.src_bos = src_bos
|
||||
self.noise_id = noise_id
|
||||
self.mask_ratio = mask_ratio
|
||||
self.mask_type = mask_type
|
||||
assert mask_type in ("random", "tail")
|
||||
|
||||
@property
|
||||
def src_sizes(self):
|
||||
return self.dataset.src_sizes
|
||||
|
||||
@property
|
||||
def tgt_sizes(self):
|
||||
return self.dataset.tgt_sizes
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
# dataset.sizes can be a dynamically computed sizes:
|
||||
return self.dataset.sizes
|
||||
|
||||
def get_batch_shapes(self):
|
||||
if hasattr(self.dataset, "get_batch_shapes"):
|
||||
return self.dataset.get_batch_shapes()
|
||||
return self.dataset.buckets
|
||||
|
||||
def num_tokens_vec(self, indices):
|
||||
return self.dataset.num_tokens_vec(indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.dataset.num_tokens(index)
|
||||
|
||||
def size(self, index):
|
||||
return self.dataset.size(index)
|
||||
|
||||
def ordered_indices(self):
|
||||
return self.dataset.ordered_indices()
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return getattr(self.dataset, "supports_prefetch", False)
|
||||
|
||||
def prefetch(self, indices):
|
||||
return self.dataset.prefetch(indices)
|
||||
|
||||
def mask_src_tokens(self, sample):
|
||||
src_item = sample["source"]
|
||||
mask = None
|
||||
if self.mask_type == "random":
|
||||
mask = torch.rand(len(src_item)).le(self.mask_ratio)
|
||||
else:
|
||||
mask = torch.ones(len(src_item))
|
||||
mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
|
||||
mask = mask.eq(1)
|
||||
if src_item[0] == self.src_bos:
|
||||
mask[0] = False
|
||||
if src_item[-1] == self.src_eos:
|
||||
mask[-1] = False
|
||||
mask_src_item = src_item.masked_fill(mask, self.noise_id)
|
||||
smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
|
||||
return smp
|
||||
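# mask_src_tokens example (illustrative): with mask_type="tail" and mask_ratio=0.3,
# roughly the last 30% of source positions are replaced by noise_id; with
# mask_type="random", each position is masked independently with probability
# mask_ratio. A leading src_bos or trailing src_eos token is never masked.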
|
||||
def __getitem__(self, index):
|
||||
sample = self.dataset[index]
|
||||
if self.mask_ratio > 0:
|
||||
sample = self.mask_src_tokens(sample)
|
||||
return sample
|
||||
|
||||
def collater(self, samples, pad_to_length=None):
|
||||
return self.dataset.collater(samples, pad_to_length)
|
||||
|
||||
|
||||
class FileAudioDatasetWrapper(FileAudioDataset):
|
||||
def collater(self, samples):
|
||||
samples = super().collater(samples)
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
|
||||
samples["net_input"]["prev_output_tokens"] = None
|
||||
del samples["net_input"]["source"]
|
||||
samples["net_input"]["src_lengths"] = None
|
||||
samples["net_input"]["alignment"] = None
|
||||
return samples
|
||||
393
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .. import FairseqDataset
|
||||
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
|
||||
from fairseq.data.audio.audio_utils import (
|
||||
parse_path,
|
||||
read_from_stored_zip,
|
||||
is_sf_audio_data,
|
||||
)
|
||||
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RawAudioDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate,
|
||||
max_sample_size=None,
|
||||
min_sample_size=0,
|
||||
shuffle=True,
|
||||
pad=False,
|
||||
normalize=False,
|
||||
compute_mask_indices=False,
|
||||
**mask_compute_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.sizes = []
|
||||
self.max_sample_size = (
|
||||
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||
)
|
||||
self.min_sample_size = min_sample_size
|
||||
self.pad = pad
|
||||
self.shuffle = shuffle
|
||||
self.normalize = normalize
|
||||
self.compute_mask_indices = compute_mask_indices
|
||||
if self.compute_mask_indices:
|
||||
self.mask_compute_kwargs = mask_compute_kwargs
|
||||
self._features_size_map = {}
|
||||
self._C = mask_compute_kwargs["encoder_embed_dim"]
|
||||
self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sizes)
|
||||
|
||||
def postprocess(self, feats, curr_sample_rate):
|
||||
if feats.dim() == 2:
|
||||
feats = feats.mean(-1)
|
||||
|
||||
if curr_sample_rate != self.sample_rate:
|
||||
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
|
||||
|
||||
assert feats.dim() == 1, feats.dim()
|
||||
|
||||
if self.normalize:
|
||||
with torch.no_grad():
|
||||
feats = F.layer_norm(feats, feats.shape)
|
||||
return feats
|
||||
|
||||
def crop_to_max_size(self, wav, target_size):
|
||||
size = len(wav)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return wav
|
||||
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return wav[start:end]
|
||||
|
||||
def _compute_mask_indices(self, dims, padding_mask):
|
||||
B, T, C = dims
|
||||
mask_indices, mask_channel_indices = None, None
|
||||
if self.mask_compute_kwargs["mask_prob"] > 0:
|
||||
mask_indices = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
self.mask_compute_kwargs["mask_prob"],
|
||||
self.mask_compute_kwargs["mask_length"],
|
||||
self.mask_compute_kwargs["mask_selection"],
|
||||
self.mask_compute_kwargs["mask_other"],
|
||||
min_masks=2,
|
||||
no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
|
||||
min_space=self.mask_compute_kwargs["mask_min_space"],
|
||||
)
|
||||
mask_indices = torch.from_numpy(mask_indices)
|
||||
if self.mask_compute_kwargs["mask_channel_prob"] > 0:
|
||||
mask_channel_indices = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
self.mask_compute_kwargs["mask_channel_prob"],
|
||||
self.mask_compute_kwargs["mask_channel_length"],
|
||||
self.mask_compute_kwargs["mask_channel_selection"],
|
||||
self.mask_compute_kwargs["mask_channel_other"],
|
||||
no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
|
||||
min_space=self.mask_compute_kwargs["mask_channel_min_space"],
|
||||
)
|
||||
mask_channel_indices = (
|
||||
torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
|
||||
)
|
||||
|
||||
return mask_indices, mask_channel_indices
|
||||
|
||||
@staticmethod
|
||||
def _bucket_tensor(tensor, num_pad, value):
|
||||
return F.pad(tensor, (0, num_pad), value=value)
|
||||
|
||||
def collater(self, samples):
|
||||
samples = [s for s in samples if s["source"] is not None]
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
sources = [s["source"] for s in samples]
|
||||
sizes = [len(s) for s in sources]
|
||||
|
||||
if self.pad:
|
||||
target_size = min(max(sizes), self.max_sample_size)
|
||||
else:
|
||||
target_size = min(min(sizes), self.max_sample_size)
|
||||
|
||||
collated_sources = sources[0].new_zeros(len(sources), target_size)
|
||||
padding_mask = (
|
||||
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
|
||||
)
|
||||
for i, (source, size) in enumerate(zip(sources, sizes)):
|
||||
diff = size - target_size
|
||||
if diff == 0:
|
||||
collated_sources[i] = source
|
||||
elif diff < 0:
|
||||
assert self.pad
|
||||
collated_sources[i] = torch.cat(
|
||||
[source, source.new_full((-diff,), 0.0)]
|
||||
)
|
||||
padding_mask[i, diff:] = True
|
||||
else:
|
||||
collated_sources[i] = self.crop_to_max_size(source, target_size)
|
||||
|
||||
input = {"source": collated_sources}
|
||||
out = {"id": torch.LongTensor([s["id"] for s in samples])}
|
||||
if self.pad:
|
||||
input["padding_mask"] = padding_mask
|
||||
|
||||
if hasattr(self, "num_buckets") and self.num_buckets > 0:
|
||||
assert self.pad, "Cannot bucket without padding first."
|
||||
bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
|
||||
num_pad = bucket - collated_sources.size(-1)
|
||||
if num_pad:
|
||||
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
|
||||
input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
|
||||
|
||||
if self.compute_mask_indices:
|
||||
B = input["source"].size(0)
|
||||
T = self._get_mask_indices_dims(input["source"].size(-1))
|
||||
padding_mask_reshaped = input["padding_mask"].clone()
|
||||
extra = padding_mask_reshaped.size(1) % T
|
||||
if extra > 0:
|
||||
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
|
||||
padding_mask_reshaped = padding_mask_reshaped.view(
|
||||
padding_mask_reshaped.size(0), T, -1
|
||||
)
|
||||
padding_mask_reshaped = padding_mask_reshaped.all(-1)
|
||||
input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
|
||||
mask_indices, mask_channel_indices = self._compute_mask_indices(
|
||||
(B, T, self._C),
|
||||
padding_mask_reshaped,
|
||||
)
|
||||
input["mask_indices"] = mask_indices
|
||||
input["mask_channel_indices"] = mask_channel_indices
|
||||
out["sample_size"] = mask_indices.sum().item()
|
||||
|
||||
out["net_input"] = input
|
||||
return out
|
||||
|
||||
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
|
||||
if size not in self._features_size_map:
|
||||
L_in = size
|
||||
for (_, kernel_size, stride) in self._conv_feature_layers:
|
||||
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
|
||||
L_out = 1 + L_out // stride
|
||||
L_in = L_out
|
||||
self._features_size_map[size] = L_out
|
||||
return self._features_size_map[size]
|
||||
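# _get_mask_indices_dims applies the standard 1d-conv output-length formula per layer:
#   L_out = 1 + (L_in + 2*padding - dilation*(kernel_size - 1) - 1) // stride
# For an illustrative first layer (dim, kernel_size, stride) = (512, 10, 5) and a
# 16000-sample input, L_out = 1 + (16000 - 9 - 1) // 5 = 3199 feature frames.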
|
||||
def num_tokens(self, index):
|
||||
return self.size(index)
|
||||
|
||||
def size(self, index):
|
||||
"""Return an example's size as a float or tuple. This value is used when
|
||||
filtering a dataset with ``--max-positions``."""
|
||||
if self.pad:
|
||||
return self.sizes[index]
|
||||
return min(self.sizes[index], self.max_sample_size)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
order.append(
|
||||
np.minimum(
|
||||
np.array(self.sizes),
|
||||
self.max_sample_size,
|
||||
)
|
||||
)
|
||||
return np.lexsort(order)[::-1]
|
||||
else:
|
||||
return np.arange(len(self))
|
||||
|
||||
def set_bucket_info(self, num_buckets):
|
||||
self.num_buckets = num_buckets
|
||||
if self.num_buckets > 0:
|
||||
self._collated_sizes = np.minimum(
|
||||
np.array(self.sizes),
|
||||
self.max_sample_size,
|
||||
)
|
||||
self.buckets = get_buckets(
|
||||
self._collated_sizes,
|
||||
self.num_buckets,
|
||||
)
|
||||
self._bucketed_sizes = get_bucketed_sizes(
|
||||
self._collated_sizes, self.buckets
|
||||
)
|
||||
logger.info(
|
||||
f"{len(self.buckets)} bucket(s) for the audio dataset: "
|
||||
f"{self.buckets}"
|
||||
)
|
||||
|
||||
|
||||
class FileAudioDataset(RawAudioDataset):
|
||||
def __init__(
|
||||
self,
|
||||
manifest_path,
|
||||
sample_rate,
|
||||
max_sample_size=None,
|
||||
min_sample_size=0,
|
||||
shuffle=True,
|
||||
pad=False,
|
||||
normalize=False,
|
||||
num_buckets=0,
|
||||
compute_mask_indices=False,
|
||||
text_compression_level=TextCompressionLevel.none,
|
||||
**mask_compute_kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
max_sample_size=max_sample_size,
|
||||
min_sample_size=min_sample_size,
|
||||
shuffle=shuffle,
|
||||
pad=pad,
|
||||
normalize=normalize,
|
||||
compute_mask_indices=compute_mask_indices,
|
||||
**mask_compute_kwargs,
|
||||
)
|
||||
|
||||
self.text_compressor = TextCompressor(level=text_compression_level)
|
||||
|
||||
skipped = 0
|
||||
self.fnames = []
|
||||
sizes = []
|
||||
self.skipped_indices = set()
|
||||
|
||||
with open(manifest_path, "r") as f:
|
||||
self.root_dir = f.readline().strip()
|
||||
for i, line in enumerate(f):
|
||||
items = line.strip().split("\t")
|
||||
assert len(items) == 2, line
|
||||
sz = int(items[1])
|
||||
if min_sample_size is not None and sz < min_sample_size:
|
||||
skipped += 1
|
||||
self.skipped_indices.add(i)
|
||||
continue
|
||||
self.fnames.append(self.text_compressor.compress(items[0]))
|
||||
sizes.append(sz)
|
||||
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
|
||||
|
||||
self.sizes = np.array(sizes, dtype=np.int64)
|
||||
|
||||
try:
|
||||
import pyarrow
|
||||
|
||||
self.fnames = pyarrow.array(self.fnames)
|
||||
except:
|
||||
logger.debug(
|
||||
"Could not create a pyarrow array. Please install pyarrow for better performance"
|
||||
)
|
||||
pass
|
||||
|
||||
self.set_bucket_info(num_buckets)
|
||||
|
||||
def __getitem__(self, index):
|
||||
import soundfile as sf
|
||||
|
||||
fn = self.fnames[index]
|
||||
fn = fn if isinstance(self.fnames, list) else fn.as_py()
|
||||
fn = self.text_compressor.decompress(fn)
|
||||
path_or_fp = os.path.join(self.root_dir, fn)
|
||||
_path, slice_ptr = parse_path(path_or_fp)
|
||||
if len(slice_ptr) == 2:
|
||||
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
||||
assert is_sf_audio_data(byte_data)
|
||||
path_or_fp = io.BytesIO(byte_data)
|
||||
|
||||
wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
|
||||
|
||||
feats = torch.from_numpy(wav).float()
|
||||
feats = self.postprocess(feats, curr_sample_rate)
|
||||
return {"id": index, "source": feats}
|
||||
|
||||
|
||||
class BinarizedAudioDataset(RawAudioDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir,
|
||||
split,
|
||||
sample_rate,
|
||||
max_sample_size=None,
|
||||
min_sample_size=0,
|
||||
shuffle=True,
|
||||
pad=False,
|
||||
normalize=False,
|
||||
num_buckets=0,
|
||||
compute_mask_indices=False,
|
||||
**mask_compute_kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
max_sample_size=max_sample_size,
|
||||
min_sample_size=min_sample_size,
|
||||
shuffle=shuffle,
|
||||
pad=pad,
|
||||
normalize=normalize,
|
||||
compute_mask_indices=compute_mask_indices,
|
||||
**mask_compute_kwargs,
|
||||
)
|
||||
|
||||
from fairseq.data import data_utils, Dictionary
|
||||
|
||||
self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
|
||||
|
||||
root_path = os.path.join(data_dir, f"{split}.root")
|
||||
if os.path.exists(root_path):
|
||||
with open(root_path, "r") as f:
|
||||
self.root_dir = next(f).strip()
|
||||
else:
|
||||
self.root_dir = None
|
||||
|
||||
fnames_path = os.path.join(data_dir, split)
|
||||
self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
|
||||
lengths_path = os.path.join(data_dir, f"{split}.lengths")
|
||||
|
||||
with open(lengths_path, "r") as f:
|
||||
for line in f:
|
||||
sz = int(line.rstrip())
|
||||
assert (
|
||||
sz >= min_sample_size
|
||||
), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
|
||||
self.sizes.append(sz)
|
||||
|
||||
self.sizes = np.array(self.sizes, dtype=np.int64)
|
||||
|
||||
self.set_bucket_info(num_buckets)
|
||||
logger.info(f"loaded {len(self.fnames)} samples")
|
||||
|
||||
def __getitem__(self, index):
|
||||
import soundfile as sf
|
||||
|
||||
fname = self.fnames_dict.string(self.fnames[index], separator="")
|
||||
if self.root_dir:
|
||||
fname = os.path.join(self.root_dir, fname)
|
||||
|
||||
wav, curr_sample_rate = sf.read(fname)
|
||||
feats = torch.from_numpy(wav).float()
|
||||
feats = self.postprocess(feats, curr_sample_rate)
|
||||
return {"id": index, "source": feats}
|
||||
@@ -0,0 +1,379 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq.data import ConcatDataset, Dictionary
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
||||
from fairseq.data.audio.data_cfg import S2SDataConfig
|
||||
from fairseq.data.audio.speech_to_text_dataset import (
|
||||
SpeechToTextDataset,
|
||||
SpeechToTextDatasetCreator,
|
||||
TextTargetMultitaskData,
|
||||
_collate_frames,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToSpeechDatasetItem(object):
|
||||
index: int
|
||||
source: torch.Tensor
|
||||
target: Optional[torch.Tensor] = None
|
||||
target_speaker: Optional[torch.Tensor] = None
|
||||
tgt_lang_tag: Optional[int] = None
|
||||
|
||||
|
||||
class SpeechToSpeechDataset(SpeechToTextDataset):
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
data_cfg: S2SDataConfig,
|
||||
src_audio_paths: List[str],
|
||||
src_n_frames: List[int],
|
||||
tgt_audio_paths: List[str],
|
||||
tgt_n_frames: List[int],
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
target_is_code: bool = False,
|
||||
tgt_dict: Dictionary = None,
|
||||
n_frames_per_step: int = 1,
|
||||
):
|
||||
tgt_texts = tgt_audio_paths if target_is_code else None
|
||||
super().__init__(
|
||||
split=split,
|
||||
is_train_split=is_train_split,
|
||||
cfg=data_cfg,
|
||||
audio_paths=src_audio_paths,
|
||||
n_frames=src_n_frames,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
tgt_texts=tgt_texts,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
)
|
||||
|
||||
self.tgt_audio_paths = tgt_audio_paths
|
||||
self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
|
||||
|
||||
assert not target_is_code or tgt_dict is not None
|
||||
self.target_is_code = target_is_code
|
||||
|
||||
assert len(tgt_audio_paths) == self.n_samples
|
||||
assert len(tgt_n_frames) == self.n_samples
|
||||
|
||||
self.tgt_speakers = None
|
||||
if self.cfg.target_speaker_embed:
|
||||
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
|
||||
self.cfg.target_speaker_embed, split
|
||||
)
|
||||
spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
|
||||
self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
|
||||
assert len(self.tgt_speakers) == self.n_samples
|
||||
|
||||
logger.info(self.__repr__())
|
||||
|
||||
def pack_units(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.n_frames_per_step <= 1:
|
||||
return input
|
||||
|
||||
offset = 4
|
||||
vocab_size = (
|
||||
len(self.tgt_dict) - offset
|
||||
) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
|
||||
|
||||
assert input.dim() == 1
|
||||
stacked_input = (
|
||||
input[:-1].view(-1, self.n_frames_per_step) - offset
|
||||
) # remove <eos>
|
||||
scale = [
|
||||
pow(vocab_size, self.n_frames_per_step - 1 - i)
|
||||
for i in range(self.n_frames_per_step)
|
||||
]
|
||||
scale = torch.LongTensor(scale).squeeze(0)
|
||||
res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
|
||||
res[:-1] = (stacked_input * scale).sum(dim=1) + offset
|
||||
|
||||
return res
|
||||
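# pack_units example (hypothetical values): with n_frames_per_step=2 and a unit
# vocabulary of size V = len(tgt_dict) - 4, a target [u0, u1, u2, u3, eos] is packed
# into [(u0-4)*V + (u1-4) + 4, (u2-4)*V + (u3-4) + 4, eos], i.e. each group of
# n_frames_per_step units becomes one base-V "super" token and the trailing eos is kept.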
|
||||
def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
|
||||
source = self._get_source_audio(index)
|
||||
|
||||
tgt_lang_tag = None
|
||||
if self.cfg.prepend_tgt_lang_tag_as_bos:
|
||||
# prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
|
||||
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
||||
|
||||
if not self.target_is_code:
|
||||
target = get_features_or_waveform(self.tgt_audio_paths[index])
|
||||
target = torch.from_numpy(target).float()
|
||||
target = self.pack_frames(target)
|
||||
else:
|
||||
target = self.tgt_dict.encode_line(
|
||||
self.tgt_audio_paths[index],
|
||||
add_if_not_exist=False,
|
||||
append_eos=True,
|
||||
).long()
|
||||
if self.n_frames_per_step > 1:
|
||||
n_tgt_frame = target.size(0) - 1 # exclude <eos>
|
||||
keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
|
||||
target = torch.cat(
|
||||
(
|
||||
target[:keep_n_tgt_frame],
|
||||
target.new_full((1,), self.tgt_dict.eos()),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if self.tgt_speakers:
|
||||
tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
|
||||
tgt_spk = torch.from_numpy(tgt_spk).float()
|
||||
else:
|
||||
tgt_spk = torch.FloatTensor([])
|
||||
|
||||
return SpeechToSpeechDatasetItem(
|
||||
index=index,
|
||||
source=source,
|
||||
target=target,
|
||||
target_speaker=tgt_spk,
|
||||
tgt_lang_tag=tgt_lang_tag,
|
||||
)
|
||||
|
||||
def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
|
||||
if self.target_is_code:
|
||||
target = fairseq_data_utils.collate_tokens(
|
||||
[x.target for x in samples],
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
# convert stacked units to a single id
|
||||
pack_targets = [self.pack_units(x.target) for x in samples]
|
||||
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
||||
pack_targets,
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=True,
|
||||
)
|
||||
target_lengths = torch.tensor(
|
||||
[x.size(0) for x in pack_targets], dtype=torch.long
|
||||
)
|
||||
else:
|
||||
target = _collate_frames([x.target for x in samples], is_audio_input=False)
|
||||
bsz, _, d = target.size()
|
||||
prev_output_tokens = torch.cat(
|
||||
(target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
|
||||
)
|
||||
target_lengths = torch.tensor(
|
||||
[x.target.size(0) for x in samples], dtype=torch.long
|
||||
)
|
||||
|
||||
return target, prev_output_tokens, target_lengths
|
||||
|
||||
def collater(
|
||||
self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
|
||||
) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
||||
frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
|
||||
# sort samples by descending number of frames
|
||||
n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
|
||||
n_frames, order = n_frames.sort(descending=True)
|
||||
indices = indices.index_select(0, order)
|
||||
frames = frames.index_select(0, order)
|
||||
|
||||
target, prev_output_tokens, target_lengths = self._collate_target(samples)
|
||||
target = target.index_select(0, order)
|
||||
target_lengths = target_lengths.index_select(0, order)
|
||||
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
||||
ntokens = sum(x.target.size(0) for x in samples)
|
||||
|
||||
tgt_speakers = None
|
||||
if self.cfg.target_speaker_embed:
|
||||
tgt_speakers = _collate_frames(
|
||||
[x.target_speaker for x in samples], is_audio_input=True
|
||||
).index_select(0, order)
|
||||
|
||||
net_input = {
|
||||
"src_tokens": frames,
|
||||
"src_lengths": n_frames,
|
||||
"prev_output_tokens": prev_output_tokens,
|
||||
"tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
|
||||
}
|
||||
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
|
||||
for i in range(len(samples)):
|
||||
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
|
||||
out = {
|
||||
"id": indices,
|
||||
"net_input": net_input,
|
||||
"speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
|
||||
"target": target,
|
||||
"target_lengths": target_lengths,
|
||||
"ntokens": ntokens,
|
||||
"nsentences": len(samples),
|
||||
}
|
||||
if return_order:
|
||||
out["order"] = order
|
||||
return out
|
||||
|
||||
|
||||
class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.multitask_data = {}
|
||||
|
||||
def add_multitask_dataset(self, task_name, task_data):
|
||||
self.multitask_data[task_name] = task_data
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
|
||||
s2s_data = super().__getitem__(index)
|
||||
|
||||
multitask_target = {}
|
||||
sample_id = self.ids[index]
|
||||
tgt_lang = self.tgt_langs[index]
|
||||
for task_name, task_dataset in self.multitask_data.items():
|
||||
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
|
||||
|
||||
return s2s_data, multitask_target
|
||||
|
||||
def collater(
|
||||
self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
|
||||
) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
out = super().collater([s for s, _ in samples], return_order=True)
|
||||
order = out["order"]
|
||||
del out["order"]
|
||||
|
||||
for task_name, task_dataset in self.multitask_data.items():
|
||||
if "multitask" not in out:
|
||||
out["multitask"] = {}
|
||||
d = [s[task_name] for _, s in samples]
|
||||
task_target = task_dataset.collater(d)
|
||||
out["multitask"][task_name] = {
|
||||
"target": task_target["target"].index_select(0, order),
|
||||
"target_lengths": task_target["target_lengths"].index_select(0, order),
|
||||
"ntokens": task_target["ntokens"],
|
||||
}
|
||||
out["multitask"][task_name]["net_input"] = {
|
||||
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
|
||||
0, order
|
||||
),
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SpeechToSpeechDatasetCreator(object):
|
||||
# mandatory columns
|
||||
KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
|
||||
KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
|
||||
# optional columns
|
||||
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
||||
# default values
|
||||
DEFAULT_LANG = ""
|
||||
|
||||
@classmethod
|
||||
def _from_list(
|
||||
cls,
|
||||
split_name: str,
|
||||
is_train_split,
|
||||
samples: List[Dict],
|
||||
data_cfg: S2SDataConfig,
|
||||
target_is_code: bool = False,
|
||||
tgt_dict: Dictionary = None,
|
||||
n_frames_per_step: int = 1,
|
||||
multitask: Optional[Dict] = None,
|
||||
) -> SpeechToSpeechDataset:
|
||||
audio_root = Path(data_cfg.audio_root)
|
||||
ids = [s[cls.KEY_ID] for s in samples]
|
||||
src_audio_paths = [
|
||||
(audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
|
||||
]
|
||||
tgt_audio_paths = [
|
||||
s[cls.KEY_TGT_AUDIO]
|
||||
if target_is_code
|
||||
else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
|
||||
for s in samples
|
||||
]
|
||||
src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
|
||||
tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
|
||||
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
|
||||
has_multitask = multitask is not None and len(multitask.keys()) > 0
|
||||
dataset_cls = (
|
||||
SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
|
||||
)
|
||||
|
||||
ds = dataset_cls(
|
||||
split=split_name,
|
||||
is_train_split=is_train_split,
|
||||
data_cfg=data_cfg,
|
||||
src_audio_paths=src_audio_paths,
|
||||
src_n_frames=src_n_frames,
|
||||
tgt_audio_paths=tgt_audio_paths,
|
||||
tgt_n_frames=tgt_n_frames,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
target_is_code=target_is_code,
|
||||
tgt_dict=tgt_dict,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
)
|
||||
|
||||
if has_multitask:
|
||||
for task_name, task_obj in multitask.items():
|
||||
task_data = TextTargetMultitaskData(
|
||||
task_obj.args, split_name, task_obj.target_dictionary
|
||||
)
|
||||
ds.add_multitask_dataset(task_name, task_data)
|
||||
return ds
|
||||
|
||||
@classmethod
|
||||
def from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
data_cfg: S2SDataConfig,
|
||||
splits: str,
|
||||
is_train_split: bool,
|
||||
epoch: int,
|
||||
seed: int,
|
||||
target_is_code: bool = False,
|
||||
tgt_dict: Dictionary = None,
|
||||
n_frames_per_step: int = 1,
|
||||
multitask: Optional[Dict] = None,
|
||||
) -> SpeechToSpeechDataset:
|
||||
datasets = []
|
||||
for split in splits.split(","):
|
||||
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
|
||||
ds = cls._from_list(
|
||||
split_name=split,
|
||||
is_train_split=is_train_split,
|
||||
samples=samples,
|
||||
data_cfg=data_cfg,
|
||||
target_is_code=target_is_code,
|
||||
tgt_dict=tgt_dict,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
multitask=multitask,
|
||||
)
|
||||
datasets.append(ds)
|
||||
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
@@ -0,0 +1,733 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import re
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data import encoders
|
||||
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
||||
from fairseq.data.audio.data_cfg import S2TDataConfig
|
||||
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
|
||||
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
|
||||
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
|
||||
NoisyOverlapAugment,
|
||||
)
|
||||
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
|
||||
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _collate_frames(
|
||||
frames: List[torch.Tensor], is_audio_input: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a list of 2D frames into a padded 3D tensor
|
||||
Args:
|
||||
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
||||
length of i-th frame and f_dim is static dimension of features
|
||||
Returns:
|
||||
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
||||
"""
|
||||
max_len = max(frame.size(0) for frame in frames)
|
||||
if is_audio_input:
|
||||
out = frames[0].new_zeros((len(frames), max_len))
|
||||
else:
|
||||
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
||||
for i, v in enumerate(frames):
|
||||
out[i, : v.size(0)] = v
|
||||
return out
|
||||
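# _collate_frames shape example (illustrative): three feature tensors of shapes
# (10, 80), (7, 80) and (4, 80) are zero-padded into a single (3, 10, 80) batch;
# with is_audio_input=True, raw waveforms of lengths 10, 7 and 4 give a (3, 10) batch.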
|
||||
|
||||
def _is_int_or_np_int(n):
|
||||
return isinstance(n, int) or (
|
||||
isinstance(n, np.generic) and isinstance(n.item(), int)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToTextDatasetItem(object):
|
||||
index: int
|
||||
source: torch.Tensor
|
||||
target: Optional[torch.Tensor] = None
|
||||
speaker_id: Optional[int] = None
|
||||
|
||||
|
||||
class SpeechToTextDataset(FairseqDataset):
|
||||
LANG_TAG_TEMPLATE = "<lang:{}>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
cfg: S2TDataConfig,
|
||||
audio_paths: List[str],
|
||||
n_frames: List[int],
|
||||
src_texts: Optional[List[str]] = None,
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
speakers: Optional[List[str]] = None,
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
tgt_dict: Optional[Dictionary] = None,
|
||||
pre_tokenizer=None,
|
||||
bpe_tokenizer=None,
|
||||
n_frames_per_step=1,
|
||||
speaker_to_id=None,
|
||||
append_eos=True,
|
||||
):
|
||||
self.split, self.is_train_split = split, is_train_split
|
||||
self.cfg = cfg
|
||||
self.audio_paths, self.n_frames = audio_paths, n_frames
|
||||
self.n_samples = len(audio_paths)
|
||||
assert len(n_frames) == self.n_samples > 0
|
||||
assert src_texts is None or len(src_texts) == self.n_samples
|
||||
assert tgt_texts is None or len(tgt_texts) == self.n_samples
|
||||
assert speakers is None or len(speakers) == self.n_samples
|
||||
assert src_langs is None or len(src_langs) == self.n_samples
|
||||
assert tgt_langs is None or len(tgt_langs) == self.n_samples
|
||||
assert ids is None or len(ids) == self.n_samples
|
||||
assert (tgt_dict is None and tgt_texts is None) or (
|
||||
tgt_dict is not None and tgt_texts is not None
|
||||
)
|
||||
self.src_texts, self.tgt_texts = src_texts, tgt_texts
|
||||
self.src_langs, self.tgt_langs = src_langs, tgt_langs
|
||||
self.speakers = speakers
|
||||
self.tgt_dict = tgt_dict
|
||||
self.check_tgt_lang_tag()
|
||||
self.ids = ids
|
||||
self.shuffle = cfg.shuffle if is_train_split else False
|
||||
|
||||
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
|
||||
self.cfg.get_feature_transforms(split, is_train_split)
|
||||
)
|
||||
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
|
||||
self.cfg.get_waveform_transforms(split, is_train_split)
|
||||
)
|
||||
# TODO: add these to data_cfg.py
|
||||
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
|
||||
self.cfg.get_dataset_transforms(split, is_train_split)
|
||||
)
|
||||
|
||||
# check proper usage of transforms
|
||||
if self.feature_transforms and self.cfg.use_audio_input:
|
||||
logger.warning(
|
||||
"Feature transforms will not be applied. To use feature transforms, "
|
||||
"set use_audio_input as False in config."
|
||||
)
|
||||
|
||||
self.pre_tokenizer = pre_tokenizer
|
||||
self.bpe_tokenizer = bpe_tokenizer
|
||||
self.n_frames_per_step = n_frames_per_step
|
||||
self.speaker_to_id = speaker_to_id
|
||||
|
||||
self.tgt_lens = self.get_tgt_lens_and_check_oov()
|
||||
self.append_eos = append_eos
|
||||
|
||||
logger.info(self.__repr__())
|
||||
|
||||
def get_tgt_lens_and_check_oov(self):
|
||||
if self.tgt_texts is None:
|
||||
return [0 for _ in range(self.n_samples)]
|
||||
tgt_lens = []
|
||||
n_tokens, n_oov_tokens = 0, 0
|
||||
for i in range(self.n_samples):
|
||||
tokenized = self.get_tokenized_tgt_text(i).split(" ")
|
||||
oov_tokens = [
|
||||
t
|
||||
for t in tokenized
|
||||
if self.tgt_dict.index(t) == self.tgt_dict.unk_index
|
||||
]
|
||||
n_tokens += len(tokenized)
|
||||
n_oov_tokens += len(oov_tokens)
|
||||
tgt_lens.append(len(tokenized))
|
||||
logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
|
||||
return tgt_lens
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
|
||||
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
|
||||
f"n_frames_per_step={self.n_frames_per_step}, "
|
||||
f"shuffle={self.shuffle}, "
|
||||
f"feature_transforms={self.feature_transforms}, "
|
||||
f"waveform_transforms={self.waveform_transforms}, "
|
||||
f"dataset_transforms={self.dataset_transforms})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lang_tag(cls, token):
|
||||
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
||||
return re.match(pattern, token)
|
||||
|
||||
def check_tgt_lang_tag(self):
|
||||
if self.cfg.prepend_tgt_lang_tag:
|
||||
assert self.tgt_langs is not None and self.tgt_dict is not None
|
||||
tgt_lang_tags = [
|
||||
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
|
||||
]
|
||||
assert all(t in self.tgt_dict for t in tgt_lang_tags)
|
||||
|
||||
@classmethod
|
||||
def tokenize(cls, tokenizer, text: str):
|
||||
return text if tokenizer is None else tokenizer.encode(text)
|
||||
|
||||
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
|
||||
if _is_int_or_np_int(index):
|
||||
text = self.tgt_texts[index]
|
||||
else:
|
||||
text = " ".join([self.tgt_texts[i] for i in index])
|
||||
|
||||
text = self.tokenize(self.pre_tokenizer, text)
|
||||
text = self.tokenize(self.bpe_tokenizer, text)
|
||||
return text
|
||||
|
||||
def pack_frames(self, feature: torch.Tensor):
|
||||
if self.n_frames_per_step == 1:
|
||||
return feature
|
||||
n_packed_frames = feature.shape[0] // self.n_frames_per_step
|
||||
feature = feature[: self.n_frames_per_step * n_packed_frames]
|
||||
return feature.reshape(n_packed_frames, -1)
|
||||
|
||||
@classmethod
|
||||
def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
|
||||
lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
|
||||
assert lang_tag_idx != dictionary.unk()
|
||||
return lang_tag_idx
|
||||
|
||||
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
|
||||
"""
|
||||
Gives source audio for given index with any relevant transforms
|
||||
applied. For ConcatAug, source audios for given indices are
|
||||
concatenated in given order.
|
||||
Args:
|
||||
index (int or List[int]): index—or in the case of ConcatAug,
|
||||
indices—to pull the source audio for
|
||||
Returns:
|
||||
source audios concatenated for given indices with
|
||||
relevant transforms applied
|
||||
"""
|
||||
if _is_int_or_np_int(index):
|
||||
source = get_features_or_waveform(
|
||||
self.audio_paths[index],
|
||||
need_waveform=self.cfg.use_audio_input,
|
||||
use_sample_rate=self.cfg.use_sample_rate,
|
||||
waveform_transforms=self.waveform_transforms,
|
||||
)
|
||||
else:
|
||||
source = np.concatenate(
|
||||
[
|
||||
get_features_or_waveform(
|
||||
self.audio_paths[i],
|
||||
need_waveform=self.cfg.use_audio_input,
|
||||
use_sample_rate=self.cfg.use_sample_rate,
|
||||
waveform_transforms=self.waveform_transforms,
|
||||
)
|
||||
for i in index
|
||||
]
|
||||
)
|
||||
if self.cfg.use_audio_input:
|
||||
source = torch.from_numpy(source).float()
|
||||
if self.cfg.standardize_audio:
|
||||
with torch.no_grad():
|
||||
source = F.layer_norm(source, source.shape)
|
||||
else:
|
||||
if self.feature_transforms is not None:
|
||||
source = self.feature_transforms(source)
|
||||
source = torch.from_numpy(source).float()
|
||||
return source
|
||||
|
||||
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
||||
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
||||
if has_concat:
|
||||
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
||||
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
||||
|
||||
source = self._get_source_audio(indices if has_concat else index)
|
||||
source = self.pack_frames(source)
|
||||
|
||||
target = None
|
||||
if self.tgt_texts is not None:
|
||||
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
||||
target = self.tgt_dict.encode_line(
|
||||
tokenized, add_if_not_exist=False, append_eos=self.append_eos
|
||||
).long()
|
||||
if self.cfg.prepend_tgt_lang_tag:
|
||||
lang_tag_idx = self.get_lang_tag_idx(
|
||||
self.tgt_langs[index], self.tgt_dict
|
||||
)
|
||||
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
|
||||
|
||||
if self.cfg.prepend_bos_and_append_tgt_lang_tag:
|
||||
bos = torch.LongTensor([self.tgt_dict.bos()])
|
||||
lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
||||
assert lang_tag_idx != self.tgt_dict.unk()
|
||||
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
||||
target = torch.cat((bos, target, lang_tag_idx), 0)
|
||||
|
||||
speaker_id = None
|
||||
if self.speaker_to_id is not None:
|
||||
speaker_id = self.speaker_to_id[self.speakers[index]]
|
||||
return SpeechToTextDatasetItem(
|
||||
index=index, source=source, target=target, speaker_id=speaker_id
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
def collater(
|
||||
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
|
||||
) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
||||
|
||||
sources = [x.source for x in samples]
|
||||
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
||||
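# NoisyOverlapAugment operates on the raw waveform tensors in the batch, so it is only applied when the inputs are waveforms (use_audio_input) rather than pre-computed features.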
if has_NOAug and self.cfg.use_audio_input:
|
||||
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
||||
sources = NOAug(sources)
|
||||
|
||||
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
||||
# sort samples by descending number of frames
|
||||
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
||||
n_frames, order = n_frames.sort(descending=True)
|
||||
indices = indices.index_select(0, order)
|
||||
frames = frames.index_select(0, order)
|
||||
|
||||
target, target_lengths = None, None
|
||||
prev_output_tokens = None
|
||||
ntokens = None
|
||||
if self.tgt_texts is not None:
|
||||
target = fairseq_data_utils.collate_tokens(
|
||||
[x.target for x in samples],
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
target = target.index_select(0, order)
|
||||
target_lengths = torch.tensor(
|
||||
[x.target.size(0) for x in samples], dtype=torch.long
|
||||
).index_select(0, order)
|
||||
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
||||
[x.target for x in samples],
|
||||
self.tgt_dict.pad(),
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=True,
|
||||
)
|
||||
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
||||
ntokens = sum(x.target.size(0) for x in samples)
|
||||
|
||||
speaker = None
|
||||
if self.speaker_to_id is not None:
|
||||
speaker = (
|
||||
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
|
||||
.index_select(0, order)
|
||||
.view(-1, 1)
|
||||
)
|
||||
|
||||
net_input = {
|
||||
"src_tokens": frames,
|
||||
"src_lengths": n_frames,
|
||||
"prev_output_tokens": prev_output_tokens,
|
||||
}
|
||||
out = {
|
||||
"id": indices,
|
||||
"net_input": net_input,
|
||||
"speaker": speaker,
|
||||
"target": target,
|
||||
"target_lengths": target_lengths,
|
||||
"ntokens": ntokens,
|
||||
"nsentences": len(samples),
|
||||
}
|
||||
if return_order:
|
||||
out["order"] = order
|
||||
return out
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.n_frames[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.n_frames[index], self.tgt_lens[index]
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.array(self.n_frames)
|
||||
|
||||
@property
|
||||
def can_reuse_epoch_itr_across_epochs(self):
|
||||
return True
|
||||
|
||||
def ordered_indices(self):
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
# first by descending order of # of frames then by original/random order
|
||||
order.append([-n for n in self.n_frames])
|
||||
return np.lexsort(order)
|
||||
|
||||
def prefetch(self, indices):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TextTargetMultitaskData(object):
|
||||
# mandatory columns
|
||||
KEY_ID, KEY_TEXT = "id", "tgt_text"
|
||||
LANG_TAG_TEMPLATE = "<lang:{}>"
|
||||
|
||||
def __init__(self, args, split, tgt_dict):
|
||||
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
|
||||
self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
|
||||
self.dict = tgt_dict
|
||||
self.append_eos = args.decoder_type != "ctc"
|
||||
self.pre_tokenizer = self.build_tokenizer(args)
|
||||
self.bpe_tokenizer = self.build_bpe(args)
|
||||
self.prepend_bos_and_append_tgt_lang_tag = (
|
||||
args.prepend_bos_and_append_tgt_lang_tag
|
||||
)
|
||||
self.eos_token = args.eos_token
|
||||
self.lang_tag_mapping = args.get_lang_tag_mapping
|
||||
|
||||
@classmethod
|
||||
def is_lang_tag(cls, token):
|
||||
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
||||
return re.match(pattern, token)
|
||||
|
||||
@classmethod
|
||||
def tokenize(cls, tokenizer, text: str):
|
||||
return text if tokenizer is None else tokenizer.encode(text)
|
||||
|
||||
def get_tokenized_tgt_text(self, index: int):
|
||||
text = self.tokenize(self.pre_tokenizer, self.data[index])
|
||||
text = self.tokenize(self.bpe_tokenizer, text)
|
||||
return text
|
||||
|
||||
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
|
||||
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
|
||||
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
|
||||
lang_tag_idx = dictionary.index(lang_tag)
|
||||
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
|
||||
return lang_tag_idx
|
||||
|
||||
def build_tokenizer(self, args):
|
||||
pre_tokenizer = args.config.get("pre_tokenizer")
|
||||
if pre_tokenizer is not None:
|
||||
logger.info(f"pre-tokenizer: {pre_tokenizer}")
|
||||
return encoders.build_tokenizer(Namespace(**pre_tokenizer))
|
||||
else:
|
||||
return None
|
||||
|
||||
def build_bpe(self, args):
|
||||
bpe_tokenizer = args.config.get("bpe_tokenizer")
|
||||
if bpe_tokenizer is not None:
|
||||
logger.info(f"tokenizer: {bpe_tokenizer}")
|
||||
return encoders.build_bpe(Namespace(**bpe_tokenizer))
|
||||
else:
|
||||
return None
|
||||
|
||||
def get(self, sample_id, tgt_lang=None):
|
||||
if sample_id in self.data:
|
||||
tokenized = self.get_tokenized_tgt_text(sample_id)
|
||||
target = self.dict.encode_line(
|
||||
tokenized,
|
||||
add_if_not_exist=False,
|
||||
append_eos=self.append_eos,
|
||||
)
|
||||
if self.prepend_bos_and_append_tgt_lang_tag:
|
||||
bos = torch.LongTensor([self.dict.bos()])
|
||||
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
|
||||
assert lang_tag_idx != self.dict.unk()
|
||||
lang_tag_idx = torch.LongTensor([lang_tag_idx])
|
||||
target = torch.cat((bos, target, lang_tag_idx), 0)
|
||||
return target
|
||||
else:
|
||||
logger.warning(f"no target for {sample_id}")
|
||||
return torch.IntTensor([])
|
||||
|
||||
def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
|
||||
out = fairseq_data_utils.collate_tokens(
|
||||
samples,
|
||||
self.dict.pad(),
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
).long()
|
||||
|
||||
prev_out = fairseq_data_utils.collate_tokens(
|
||||
samples,
|
||||
self.dict.pad(),
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=True,
|
||||
).long()
|
||||
|
||||
target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
|
||||
ntokens = sum(t.size(0) for t in samples)
|
||||
|
||||
output = {
|
||||
"prev_output_tokens": prev_out,
|
||||
"target": out,
|
||||
"target_lengths": target_lengths,
|
||||
"ntokens": ntokens,
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.multitask_data = {}
|
||||
|
||||
def add_multitask_dataset(self, task_name, task_data):
|
||||
self.multitask_data[task_name] = task_data
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
|
||||
s2t_data = super().__getitem__(index)
|
||||
|
||||
multitask_target = {}
|
||||
sample_id = self.ids[index]
|
||||
tgt_lang = self.tgt_langs[index]
|
||||
for task_name, task_dataset in self.multitask_data.items():
|
||||
multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
|
||||
|
||||
return s2t_data, multitask_target
|
||||
|
||||
def collater(
|
||||
self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
|
||||
) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
out = super().collater([s for s, _ in samples], return_order=True)
|
||||
order = out["order"]
|
||||
del out["order"]
|
||||
|
||||
for task_name, task_dataset in self.multitask_data.items():
|
||||
if "multitask" not in out:
|
||||
out["multitask"] = {}
|
||||
d = [s[task_name] for _, s in samples]
|
||||
task_target = task_dataset.collater(d)
|
||||
out["multitask"][task_name] = {
|
||||
"target": task_target["target"].index_select(0, order),
|
||||
"target_lengths": task_target["target_lengths"].index_select(0, order),
|
||||
"ntokens": task_target["ntokens"],
|
||||
}
|
||||
out["multitask"][task_name]["net_input"] = {
|
||||
"prev_output_tokens": task_target["prev_output_tokens"].index_select(
|
||||
0, order
|
||||
),
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SpeechToTextDatasetCreator(object):
|
||||
# mandatory columns
|
||||
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
|
||||
KEY_TGT_TEXT = "tgt_text"
|
||||
# optional columns
|
||||
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
|
||||
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
||||
# default values
|
||||
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
|
||||
|
||||
@classmethod
|
||||
def _from_list(
|
||||
cls,
|
||||
split_name: str,
|
||||
is_train_split,
|
||||
samples: List[Dict],
|
||||
cfg: S2TDataConfig,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
n_frames_per_step,
|
||||
speaker_to_id,
|
||||
multitask: Optional[Dict] = None,
|
||||
) -> SpeechToTextDataset:
|
||||
audio_root = Path(cfg.audio_root)
|
||||
ids = [s[cls.KEY_ID] for s in samples]
|
||||
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
||||
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
||||
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
||||
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
||||
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
||||
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
|
||||
has_multitask = multitask is not None and len(multitask.keys()) > 0
|
||||
dataset_cls = (
|
||||
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
|
||||
)
|
||||
|
||||
ds = dataset_cls(
|
||||
split=split_name,
|
||||
is_train_split=is_train_split,
|
||||
cfg=cfg,
|
||||
audio_paths=audio_paths,
|
||||
n_frames=n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
speaker_to_id=speaker_to_id,
|
||||
)
|
||||
|
||||
if has_multitask:
|
||||
for task_name, task_obj in multitask.items():
|
||||
task_data = TextTargetMultitaskData(
|
||||
task_obj.args, split_name, task_obj.target_dictionary
|
||||
)
|
||||
ds.add_multitask_dataset(task_name, task_data)
|
||||
return ds
|
||||
|
||||
@classmethod
|
||||
def get_size_ratios(
|
||||
cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
|
||||
) -> List[float]:
|
||||
"""Size ratios for temperature-based sampling
|
||||
(https://arxiv.org/abs/1907.05019)"""
|
||||
|
||||
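# With alpha = 1/T, each language pair's sampling probability is rescaled from p_l (proportional to its size in frames) to p_l**alpha / sum_k p_k**alpha;
# the returned size_ratio for a dataset is then target_prob * total_size / its own size, so resampled dataset sizes follow the target distribution.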
id_to_lp, lp_to_sz = {}, defaultdict(int)
|
||||
for ds in datasets:
|
||||
lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
|
||||
assert len(lang_pairs) == 1
|
||||
lang_pair = list(lang_pairs)[0]
|
||||
id_to_lp[ds.split] = lang_pair
|
||||
lp_to_sz[lang_pair] += sum(ds.n_frames)
|
||||
|
||||
sz_sum = sum(v for v in lp_to_sz.values())
|
||||
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
|
||||
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
|
||||
prob_sum = sum(v for v in lp_to_tgt_prob.values())
|
||||
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
|
||||
lp_to_sz_ratio = {
|
||||
k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
|
||||
}
|
||||
size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]
|
||||
|
||||
p_formatted = {
|
||||
k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
|
||||
}
|
||||
logger.info(f"sampling probability balancing: {p_formatted}")
|
||||
sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
|
||||
logger.info(f"balanced sampling size ratio: {sr_formatted}")
|
||||
return size_ratio
|
||||
|
||||
@classmethod
|
||||
def _load_samples_from_tsv(cls, root: str, split: str):
|
||||
tsv_path = Path(root) / f"{split}.tsv"
|
||||
if not tsv_path.is_file():
|
||||
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
||||
with open(tsv_path) as f:
|
||||
reader = csv.DictReader(
|
||||
f,
|
||||
delimiter="\t",
|
||||
quotechar=None,
|
||||
doublequote=False,
|
||||
lineterminator="\n",
|
||||
quoting=csv.QUOTE_NONE,
|
||||
)
|
||||
samples = [dict(e) for e in reader]
|
||||
if len(samples) == 0:
|
||||
raise ValueError(f"Empty manifest: {tsv_path}")
|
||||
return samples
|
||||
|
||||
@classmethod
|
||||
def _from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
cfg: S2TDataConfig,
|
||||
split: str,
|
||||
tgt_dict,
|
||||
is_train_split: bool,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
n_frames_per_step,
|
||||
speaker_to_id,
|
||||
multitask: Optional[Dict] = None,
|
||||
) -> SpeechToTextDataset:
|
||||
samples = cls._load_samples_from_tsv(root, split)
|
||||
return cls._from_list(
|
||||
split,
|
||||
is_train_split,
|
||||
samples,
|
||||
cfg,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
n_frames_per_step,
|
||||
speaker_to_id,
|
||||
multitask,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
cfg: S2TDataConfig,
|
||||
splits: str,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
is_train_split: bool,
|
||||
epoch: int,
|
||||
seed: int,
|
||||
n_frames_per_step: int = 1,
|
||||
speaker_to_id=None,
|
||||
multitask: Optional[Dict] = None,
|
||||
) -> SpeechToTextDataset:
|
||||
datasets = [
|
||||
cls._from_tsv(
|
||||
root=root,
|
||||
cfg=cfg,
|
||||
split=split,
|
||||
tgt_dict=tgt_dict,
|
||||
is_train_split=is_train_split,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
speaker_to_id=speaker_to_id,
|
||||
multitask=multitask,
|
||||
)
|
||||
for split in splits.split(",")
|
||||
]
|
||||
|
||||
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
|
||||
# temperature-based sampling
|
||||
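# A size_ratio > 1 up-samples a split (drawing with replacement), while a ratio < 1 sub-samples it; ResamplingDataset re-draws the indices every epoch from seed and epoch.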
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
|
||||
datasets = [
|
||||
ResamplingDataset(
|
||||
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
||||
)
|
||||
for r, d in zip(size_ratios, datasets)
|
||||
]
|
||||
|
||||
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
@@ -0,0 +1,359 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data.audio.speech_to_text_dataset import (
|
||||
S2TDataConfig,
|
||||
SpeechToTextDataset,
|
||||
SpeechToTextDatasetCreator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S2TJointDataConfig(S2TDataConfig):
|
||||
"""Wrapper class for data config YAML"""
|
||||
|
||||
@property
|
||||
def src_vocab_filename(self):
|
||||
"""fairseq vocabulary file under data root"""
|
||||
return self.config.get("src_vocab_filename", "src_dict.txt")
|
||||
|
||||
@property
|
||||
def src_pre_tokenizer(self) -> Dict:
|
||||
"""Pre-tokenizer to apply before subword tokenization. Returning
|
||||
a dictionary with `tokenizer` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
return self.config.get("src_pre_tokenizer", {"tokenizer": None})
|
||||
|
||||
@property
|
||||
def src_bpe_tokenizer(self) -> Dict:
|
||||
"""Subword tokenizer to apply on source text after pre-tokenization.
|
||||
Returns a dictionary with `bpe` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
return self.config.get("src_bpe_tokenizer", {"bpe": None})
|
||||
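# Illustrative config entry (not part of this diff): src_bpe_tokenizer: {bpe: sentencepiece, sentencepiece_model: spm_src.model}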
|
||||
@property
|
||||
def prepend_tgt_lang_tag_no_change(self) -> bool:
|
||||
"""Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
|
||||
to-many multilingual setting). No change needed during inference.
|
||||
This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
|
||||
"""
|
||||
value = self.config.get("prepend_tgt_lang_tag_no_change", None)
|
||||
if value is None:
|
||||
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
|
||||
return value
|
||||
|
||||
@property
|
||||
def sampling_text_alpha(self):
|
||||
"""Hyper-parameter alpha = 1/T for temperature-based resampling. (text
|
||||
input only) (alpha = 1 for no resampling)"""
|
||||
return self.config.get("sampling_text_alpha", 1.0)
|
||||
|
||||
|
||||
class SpeechToTextJointDatasetItem(NamedTuple):
|
||||
index: int
|
||||
source: torch.Tensor
|
||||
target: Optional[torch.Tensor] = None
|
||||
src_txt_tokens: Optional[torch.Tensor] = None
|
||||
tgt_lang_tag: Optional[int] = None
|
||||
src_lang_tag: Optional[int] = None
|
||||
tgt_alignment: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# use_src_lang_id:
|
||||
# 0: don't use src_lang_id
|
||||
# 1: attach src_lang_id to the src_txt_tokens as eos
|
||||
class SpeechToTextJointDataset(SpeechToTextDataset):
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
cfg: S2TJointDataConfig,
|
||||
audio_paths: List[str],
|
||||
n_frames: List[int],
|
||||
src_texts: Optional[List[str]] = None,
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
speakers: Optional[List[str]] = None,
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
tgt_dict: Optional[Dictionary] = None,
|
||||
src_dict: Optional[Dictionary] = None,
|
||||
pre_tokenizer=None,
|
||||
bpe_tokenizer=None,
|
||||
src_pre_tokenizer=None,
|
||||
src_bpe_tokenizer=None,
|
||||
append_eos: Optional[bool] = True,
|
||||
alignment: Optional[List[str]] = None,
|
||||
use_src_lang_id: Optional[int] = 0,
|
||||
):
|
||||
super().__init__(
|
||||
split,
|
||||
is_train_split,
|
||||
cfg,
|
||||
audio_paths,
|
||||
n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
append_eos=append_eos,
|
||||
)
|
||||
|
||||
self.src_dict = src_dict
|
||||
self.src_pre_tokenizer = src_pre_tokenizer
|
||||
self.src_bpe_tokenizer = src_bpe_tokenizer
|
||||
self.alignment = None
|
||||
self.use_src_lang_id = use_src_lang_id
|
||||
if alignment is not None:
|
||||
self.alignment = [
|
||||
[float(s) for s in sample.split()] for sample in alignment
|
||||
]
|
||||
|
||||
def get_tokenized_src_text(self, index: int):
|
||||
text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
|
||||
text = self.tokenize(self.src_bpe_tokenizer, text)
|
||||
return text
|
||||
|
||||
def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
|
||||
s2t_dataset_item = super().__getitem__(index)
|
||||
src_tokens = None
|
||||
src_lang_tag = None
|
||||
if self.src_texts is not None and self.src_dict is not None:
|
||||
src_tokens = self.get_tokenized_src_text(index)
|
||||
src_tokens = self.src_dict.encode_line(
|
||||
src_tokens, add_if_not_exist=False, append_eos=True
|
||||
).long()
|
||||
if self.use_src_lang_id > 0:
|
||||
src_lang_tag = self.get_lang_tag_idx(
|
||||
self.src_langs[index], self.src_dict
|
||||
)
|
||||
tgt_lang_tag = None
|
||||
if self.cfg.prepend_tgt_lang_tag_no_change:
|
||||
# prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
|
||||
tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
|
||||
ali = None
|
||||
if self.alignment is not None:
|
||||
ali = torch.Tensor(self.alignment[index]).float()
|
||||
|
||||
return SpeechToTextJointDatasetItem(
|
||||
index=index,
|
||||
source=s2t_dataset_item.source,
|
||||
target=s2t_dataset_item.target,
|
||||
src_txt_tokens=src_tokens,
|
||||
tgt_lang_tag=tgt_lang_tag,
|
||||
src_lang_tag=src_lang_tag,
|
||||
tgt_alignment=ali,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
|
||||
s2t_out = super().collater(samples, return_order=True)
|
||||
if s2t_out == {}:
|
||||
return s2t_out
|
||||
net_input, order = s2t_out["net_input"], s2t_out["order"]
|
||||
|
||||
if self.src_texts is not None and self.src_dict is not None:
|
||||
src_txt_tokens = fairseq_data_utils.collate_tokens(
|
||||
[x.src_txt_tokens for x in samples],
|
||||
self.src_dict.pad(),
|
||||
self.src_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
src_txt_lengths = torch.tensor(
|
||||
[x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
|
||||
)
|
||||
if self.use_src_lang_id > 0:
|
||||
src_lang_idxs = torch.tensor(
|
||||
[s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
|
||||
)
|
||||
if self.use_src_lang_id == 1: # replace eos with lang_id
|
||||
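# With use_src_lang_id == 1, the source language id overwrites the final (EOS) position of each source text sequence via scatter_.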
eos_idx = src_txt_lengths - 1
|
||||
src_txt_tokens.scatter_(
|
||||
1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Implementation is required")
|
||||
|
||||
src_txt_tokens = src_txt_tokens.index_select(0, order)
|
||||
src_txt_lengths = src_txt_lengths.index_select(0, order)
|
||||
net_input["src_txt_tokens"] = src_txt_tokens
|
||||
net_input["src_txt_lengths"] = src_txt_lengths
|
||||
|
||||
net_input["alignment"] = None
|
||||
if self.alignment is not None:
|
||||
max_len = max([s.tgt_alignment.size(0) for s in samples])
|
||||
alignment = torch.ones(len(samples), max_len).float()
|
||||
for i, s in enumerate(samples):
|
||||
cur_len = s.tgt_alignment.size(0)
|
||||
alignment[i][:cur_len].copy_(s.tgt_alignment)
|
||||
net_input["alignment"] = alignment.index_select(0, order)
|
||||
|
||||
if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
|
||||
for i in range(len(samples)):
|
||||
net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
|
||||
|
||||
out = {
|
||||
"id": s2t_out["id"],
|
||||
"net_input": net_input,
|
||||
"target": s2t_out["target"],
|
||||
"target_lengths": s2t_out["target_lengths"],
|
||||
"ntokens": s2t_out["ntokens"],
|
||||
"nsentences": len(samples),
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
|
||||
KEY_ALIGN = "align"
|
||||
|
||||
@classmethod
|
||||
def _from_list(
|
||||
cls,
|
||||
split_name: str,
|
||||
is_train_split,
|
||||
samples: List[Dict],
|
||||
cfg: S2TJointDataConfig,
|
||||
tgt_dict,
|
||||
src_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
src_pre_tokenizer,
|
||||
src_bpe_tokenizer,
|
||||
append_eos,
|
||||
use_src_lang_id,
|
||||
) -> SpeechToTextJointDataset:
|
||||
audio_root = Path(cfg.audio_root)
|
||||
ids = [s[cls.KEY_ID] for s in samples]
|
||||
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
||||
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
||||
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
||||
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
||||
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
||||
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
tgt_alignment = None
|
||||
if cls.KEY_ALIGN in samples[0].keys():
|
||||
tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
|
||||
return SpeechToTextJointDataset(
|
||||
split_name,
|
||||
is_train_split,
|
||||
cfg,
|
||||
audio_paths,
|
||||
n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
src_dict=src_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
src_pre_tokenizer=src_pre_tokenizer,
|
||||
src_bpe_tokenizer=src_bpe_tokenizer,
|
||||
append_eos=append_eos,
|
||||
alignment=tgt_alignment,
|
||||
use_src_lang_id=use_src_lang_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
cfg: S2TJointDataConfig,
|
||||
split: str,
|
||||
tgt_dict,
|
||||
src_dict,
|
||||
is_train_split: bool,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
src_pre_tokenizer,
|
||||
src_bpe_tokenizer,
|
||||
append_eos: bool,
|
||||
use_src_lang_id: int,
|
||||
) -> SpeechToTextJointDataset:
|
||||
samples = cls._load_samples_from_tsv(root, split)
|
||||
return cls._from_list(
|
||||
split,
|
||||
is_train_split,
|
||||
samples,
|
||||
cfg,
|
||||
tgt_dict,
|
||||
src_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
src_pre_tokenizer,
|
||||
src_bpe_tokenizer,
|
||||
append_eos,
|
||||
use_src_lang_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
cfg: S2TJointDataConfig,
|
||||
splits: str,
|
||||
tgt_dict,
|
||||
src_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
src_pre_tokenizer,
|
||||
src_bpe_tokenizer,
|
||||
is_train_split: bool,
|
||||
epoch: int,
|
||||
seed: int,
|
||||
append_eos: Optional[bool] = True,
|
||||
use_src_lang_id: Optional[int] = 0,
|
||||
) -> SpeechToTextJointDataset:
|
||||
datasets = [
|
||||
cls._from_tsv(
|
||||
root,
|
||||
cfg,
|
||||
split,
|
||||
tgt_dict,
|
||||
src_dict,
|
||||
is_train_split,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
src_pre_tokenizer,
|
||||
src_bpe_tokenizer,
|
||||
append_eos=append_eos,
|
||||
use_src_lang_id=use_src_lang_id,
|
||||
)
|
||||
for split in splits.split(",")
|
||||
]
|
||||
|
||||
if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
|
||||
# temperature-based sampling
|
||||
size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
|
||||
datasets = [
|
||||
ResamplingDataset(
|
||||
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
||||
)
|
||||
for r, d in zip(size_ratios, datasets)
|
||||
]
|
||||
|
||||
return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
||||
from fairseq.data.audio.speech_to_text_dataset import (
|
||||
S2TDataConfig,
|
||||
SpeechToTextDataset,
|
||||
SpeechToTextDatasetCreator,
|
||||
_collate_frames,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextToSpeechDatasetItem(object):
|
||||
index: int
|
||||
source: torch.Tensor
|
||||
target: Optional[torch.Tensor] = None
|
||||
speaker_id: Optional[int] = None
|
||||
duration: Optional[torch.Tensor] = None
|
||||
pitch: Optional[torch.Tensor] = None
|
||||
energy: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class TextToSpeechDataset(SpeechToTextDataset):
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
cfg: S2TDataConfig,
|
||||
audio_paths: List[str],
|
||||
n_frames: List[int],
|
||||
src_texts: Optional[List[str]] = None,
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
speakers: Optional[List[str]] = None,
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
tgt_dict: Optional[Dictionary] = None,
|
||||
pre_tokenizer=None,
|
||||
bpe_tokenizer=None,
|
||||
n_frames_per_step=1,
|
||||
speaker_to_id=None,
|
||||
durations: Optional[List[List[int]]] = None,
|
||||
pitches: Optional[List[str]] = None,
|
||||
energies: Optional[List[str]] = None,
|
||||
):
|
||||
super(TextToSpeechDataset, self).__init__(
|
||||
split,
|
||||
is_train_split,
|
||||
cfg,
|
||||
audio_paths,
|
||||
n_frames,
|
||||
src_texts=src_texts,
|
||||
tgt_texts=tgt_texts,
|
||||
speakers=speakers,
|
||||
src_langs=src_langs,
|
||||
tgt_langs=tgt_langs,
|
||||
ids=ids,
|
||||
tgt_dict=tgt_dict,
|
||||
pre_tokenizer=pre_tokenizer,
|
||||
bpe_tokenizer=bpe_tokenizer,
|
||||
n_frames_per_step=n_frames_per_step,
|
||||
speaker_to_id=speaker_to_id,
|
||||
)
|
||||
self.durations = durations
|
||||
self.pitches = pitches
|
||||
self.energies = energies
|
||||
|
||||
def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
|
||||
s2t_item = super().__getitem__(index)
|
||||
|
||||
duration, pitch, energy = None, None, None
|
||||
if self.durations is not None:
|
||||
duration = torch.tensor(
|
||||
self.durations[index] + [0], dtype=torch.long # pad 0 for EOS
|
||||
)
|
||||
if self.pitches is not None:
|
||||
pitch = get_features_or_waveform(self.pitches[index])
|
||||
pitch = torch.from_numpy(
|
||||
np.concatenate((pitch, [0])) # pad 0 for EOS
|
||||
).float()
|
||||
if self.energies is not None:
|
||||
energy = get_features_or_waveform(self.energies[index])
|
||||
energy = torch.from_numpy(
|
||||
np.concatenate((energy, [0])) # pad 0 for EOS
|
||||
).float()
|
||||
return TextToSpeechDatasetItem(
|
||||
index=index,
|
||||
source=s2t_item.source,
|
||||
target=s2t_item.target,
|
||||
speaker_id=s2t_item.speaker_id,
|
||||
duration=duration,
|
||||
pitch=pitch,
|
||||
energy=energy,
|
||||
)
|
||||
|
||||
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
src_lengths, order = torch.tensor(
|
||||
[s.target.shape[0] for s in samples], dtype=torch.long
|
||||
).sort(descending=True)
|
||||
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
|
||||
0, order
|
||||
)
|
||||
feat = _collate_frames(
|
||||
[s.source for s in samples], self.cfg.use_audio_input
|
||||
).index_select(0, order)
|
||||
target_lengths = torch.tensor(
|
||||
[s.source.shape[0] for s in samples], dtype=torch.long
|
||||
).index_select(0, order)
|
||||
|
||||
src_tokens = fairseq_data_utils.collate_tokens(
|
||||
[s.target for s in samples],
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
).index_select(0, order)
|
||||
|
||||
speaker = None
|
||||
if self.speaker_to_id is not None:
|
||||
speaker = (
|
||||
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
|
||||
.index_select(0, order)
|
||||
.view(-1, 1)
|
||||
)
|
||||
|
||||
bsz, _, d = feat.size()
|
||||
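# Teacher-forcing decoder inputs: shift the target spectrogram right by one frame and prepend an all-zero frame.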
prev_output_tokens = torch.cat(
|
||||
(feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
|
||||
)
|
||||
|
||||
durations, pitches, energies = None, None, None
|
||||
if self.durations is not None:
|
||||
durations = fairseq_data_utils.collate_tokens(
|
||||
[s.duration for s in samples], 0
|
||||
).index_select(0, order)
|
||||
assert src_tokens.shape[1] == durations.shape[1]
|
||||
if self.pitches is not None:
|
||||
pitches = _collate_frames([s.pitch for s in samples], True)
|
||||
pitches = pitches.index_select(0, order)
|
||||
assert src_tokens.shape[1] == pitches.shape[1]
|
||||
if self.energies is not None:
|
||||
energies = _collate_frames([s.energy for s in samples], True)
|
||||
energies = energies.index_select(0, order)
|
||||
assert src_tokens.shape[1] == energies.shape[1]
|
||||
src_texts = [self.tgt_dict.string(samples[i].target) for i in order]
|
||||
|
||||
return {
|
||||
"id": id_,
|
||||
"net_input": {
|
||||
"src_tokens": src_tokens,
|
||||
"src_lengths": src_lengths,
|
||||
"prev_output_tokens": prev_output_tokens,
|
||||
},
|
||||
"speaker": speaker,
|
||||
"target": feat,
|
||||
"durations": durations,
|
||||
"pitches": pitches,
|
||||
"energies": energies,
|
||||
"target_lengths": target_lengths,
|
||||
"ntokens": sum(target_lengths).item(),
|
||||
"nsentences": len(samples),
|
||||
"src_texts": src_texts,
|
||||
}
|
||||
|
||||
|
||||
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
|
||||
KEY_DURATION = "duration"
|
||||
KEY_PITCH = "pitch"
|
||||
KEY_ENERGY = "energy"
|
||||
|
||||
@classmethod
|
||||
def _from_list(
|
||||
cls,
|
||||
split_name: str,
|
||||
is_train_split,
|
||||
samples: List[Dict],
|
||||
cfg: S2TDataConfig,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
n_frames_per_step,
|
||||
speaker_to_id,
|
||||
multitask=None,
|
||||
) -> TextToSpeechDataset:
|
||||
audio_root = Path(cfg.audio_root)
|
||||
ids = [s[cls.KEY_ID] for s in samples]
|
||||
audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
|
||||
n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
|
||||
tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
|
||||
src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
|
||||
speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
|
||||
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
|
||||
|
||||
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
|
||||
durations = [
|
||||
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
|
||||
]
|
||||
durations = None if any(dd is None for dd in durations) else durations
|
||||
|
||||
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
|
||||
pitches = [
|
||||
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
|
||||
]
|
||||
pitches = None if any(pp is None for pp in pitches) else pitches
|
||||
|
||||
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
|
||||
energies = [
|
||||
None if ee is None else (audio_root / ee).as_posix() for ee in energies
|
||||
]
|
||||
energies = None if any(ee is None for ee in energies) else energies
|
||||
|
||||
return TextToSpeechDataset(
|
||||
split_name,
|
||||
is_train_split,
|
||||
cfg,
|
||||
audio_paths,
|
||||
n_frames,
|
||||
src_texts,
|
||||
tgt_texts,
|
||||
speakers,
|
||||
src_langs,
|
||||
tgt_langs,
|
||||
ids,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
n_frames_per_step,
|
||||
speaker_to_id,
|
||||
durations,
|
||||
pitches,
|
||||
energies,
|
||||
)
|
||||
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioWaveformTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def get_audio_waveform_transform(name):
|
||||
return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
def register_audio_waveform_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioWaveformTransform,
|
||||
AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
|
||||
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
import_transforms(os.path.dirname(__file__), "waveform")
|
||||
|
||||
|
||||
class CompositeAudioWaveformTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"waveform",
|
||||
get_audio_waveform_transform,
|
||||
CompositeAudioWaveformTransform,
|
||||
config,
|
||||
)
|
||||
|
||||
def __call__(self, x, sample_rate):
|
||||
for t in self.transforms:
|
||||
x, sample_rate = t(x, sample_rate)
|
||||
return x, sample_rate
|
||||
@@ -0,0 +1,201 @@
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from math import ceil
|
||||
|
||||
from fairseq.data.audio import rand_uniform
|
||||
from fairseq.data.audio.waveform_transforms import (
|
||||
AudioWaveformTransform,
|
||||
register_audio_waveform_transform,
|
||||
)
|
||||
|
||||
SNR_MIN = 5.0
|
||||
SNR_MAX = 15.0
|
||||
RATE = 0.25
|
||||
|
||||
NOISE_RATE = 1.0
|
||||
NOISE_LEN_MEAN = 0.2
|
||||
NOISE_LEN_STD = 0.05
|
||||
|
||||
|
||||
class NoiseAugmentTransform(AudioWaveformTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return cls(
|
||||
_config.get("samples_path", None),
|
||||
_config.get("snr_min", SNR_MIN),
|
||||
_config.get("snr_max", SNR_MAX),
|
||||
_config.get("rate", RATE),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samples_path: str,
|
||||
snr_min: float = SNR_MIN,
|
||||
snr_max: float = SNR_MAX,
|
||||
rate: float = RATE,
|
||||
):
|
||||
# Sanity checks
|
||||
assert (
|
||||
samples_path
|
||||
), "need to provide path to audio samples for noise augmentation"
|
||||
assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
|
||||
assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
|
||||
|
||||
self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
|
||||
self.n_samples = len(self.paths)
|
||||
assert self.n_samples > 0, f"no audio files found in {samples_path}"
|
||||
|
||||
self.snr_min = snr_min
|
||||
self.snr_max = snr_max
|
||||
self.rate = rate
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"n_samples={self.n_samples}",
|
||||
f"snr={self.snr_min}-{self.snr_max}dB",
|
||||
f"rate={self.rate}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
from fairseq.data.audio.audio_utils import get_waveform
|
||||
|
||||
path = self.paths[np.random.randint(0, self.n_samples)]
|
||||
sample = get_waveform(
|
||||
path, always_2d=always_2d, output_sample_rate=use_sample_rate
|
||||
)[0]
|
||||
|
||||
# Check dimensions match, else silently skip adding noise to sample
|
||||
# NOTE: SHOULD THIS QUIT WITH AN ERROR?
|
||||
is_2d = len(goal_shape) == 2
|
||||
if len(goal_shape) != sample.ndim or (
|
||||
is_2d and goal_shape[0] != sample.shape[0]
|
||||
):
|
||||
return np.zeros(goal_shape)
|
||||
|
||||
# Cut/repeat sample to size
|
||||
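# Tile the noise clip until it covers the target length, then take a random window of exactly that length.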
len_dim = len(goal_shape) - 1
|
||||
n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
|
||||
repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
|
||||
start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
|
||||
return (
|
||||
repeated[:, start : start + goal_shape[len_dim]]
|
||||
if is_2d
|
||||
else repeated[start : start + goal_shape[len_dim]]
|
||||
)
|
||||
|
||||
def _mix(self, source, noise, snr):
|
||||
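# Scale the noise so the mixture hits the requested SNR: scl = sqrt(P_source / (10**(snr / 10) * P_noise)); an all-zero noise clip gets scl = 0.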
get_power = lambda x: np.mean(x**2)
|
||||
if get_power(noise):
|
||||
scl = np.sqrt(
|
||||
get_power(source) / (np.power(10, snr / 10) * get_power(noise))
|
||||
)
|
||||
else:
|
||||
scl = 0
|
||||
return 1 * source + scl * noise
|
||||
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
return self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
||||
|
||||
def __call__(self, source, sample_rate):
|
||||
if np.random.random() > self.rate:
|
||||
return source, sample_rate
|
||||
|
||||
noise = self._get_noise(
|
||||
source.shape, always_2d=True, use_sample_rate=sample_rate
|
||||
)
|
||||
|
||||
return (
|
||||
self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
|
||||
sample_rate,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_waveform_transform("musicaugment")
|
||||
class MusicAugmentTransform(NoiseAugmentTransform):
|
||||
pass
|
||||
|
||||
|
||||
@register_audio_waveform_transform("backgroundnoiseaugment")
|
||||
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
|
||||
pass
|
||||
|
||||
|
||||
@register_audio_waveform_transform("babbleaugment")
|
||||
class BabbleAugmentTransform(NoiseAugmentTransform):
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
for i in range(np.random.randint(3, 8)):
|
||||
speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
||||
if i == 0:
|
||||
agg_noise = speech
|
||||
else: # SNR scaled by i (how many noise signals already in agg_noise)
|
||||
agg_noise = self._mix(agg_noise, speech, i)
|
||||
return agg_noise
|
||||
|
||||
|
||||
@register_audio_waveform_transform("sporadicnoiseaugment")
|
||||
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return cls(
|
||||
_config.get("samples_path", None),
|
||||
_config.get("snr_min", SNR_MIN),
|
||||
_config.get("snr_max", SNR_MAX),
|
||||
_config.get("rate", RATE),
|
||||
_config.get("noise_rate", NOISE_RATE),
|
||||
_config.get("noise_len_mean", NOISE_LEN_MEAN),
|
||||
_config.get("noise_len_std", NOISE_LEN_STD),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samples_path: str,
|
||||
snr_min: float = SNR_MIN,
|
||||
snr_max: float = SNR_MAX,
|
||||
rate: float = RATE,
|
||||
noise_rate: float = NOISE_RATE, # noises per second
|
||||
noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
|
||||
noise_len_std: float = NOISE_LEN_STD,
|
||||
):
|
||||
super().__init__(samples_path, snr_min, snr_max, rate)
|
||||
self.noise_rate = noise_rate
|
||||
self.noise_len_mean = noise_len_mean
|
||||
self.noise_len_std = noise_len_std
|
||||
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
agg_noise = np.zeros(goal_shape)
|
||||
len_dim = len(goal_shape) - 1
|
||||
is_2d = len(goal_shape) == 2
|
||||
|
||||
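# The number of noise events scales with the clip duration in seconds (samples / sample rate) times noise_rate; each event gets a random start position and a length drawn from a normal distribution.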
n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
|
||||
start_pointers = [
|
||||
round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
|
||||
]
|
||||
|
||||
for start_pointer in start_pointers:
|
||||
noise_shape = list(goal_shape)
|
||||
len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
|
||||
noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
|
||||
end_pointer = start_pointer + noise_shape[len_dim]
|
||||
if end_pointer >= goal_shape[len_dim]:
|
||||
continue
|
||||
|
||||
noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
|
||||
if is_2d:
|
||||
agg_noise[:, start_pointer:end_pointer] = (
|
||||
agg_noise[:, start_pointer:end_pointer] + noise
|
||||
)
|
||||
else:
|
||||
agg_noise[start_pointer:end_pointer] = (
|
||||
agg_noise[start_pointer:end_pointer] + noise
|
||||
)
|
||||
|
||||
return agg_noise
|
||||