Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
Synced 2026-03-11 14:30:03 +00:00
Merge pull request #116 from Tony-sama/neo
Feature: RVC module, applying voice conversion to audio sent by ST
0 data/models/rvc/.placeholder Normal file
0 data/tmp/.placeholder Normal file
21 modules/voice_conversion/fairseq/LICENSE Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
45 modules/voice_conversion/fairseq/__init__.py Normal file
@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import os
import sys

try:
    from .version import __version__  # noqa
except ImportError:
    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
    with open(version_txt) as f:
        __version__ = f.read().strip()

__all__ = ["pdb"]

# backwards compatibility to support `from fairseq.X import Y`
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import meters, metrics, progress_bar  # noqa

sys.modules["fairseq.distributed_utils"] = distributed_utils
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar

# initialize hydra
#from fairseq.dataclass.initialize import hydra_init

#hydra_init()

#import fairseq.criterions  # noqa
#import fairseq.distributed  # noqa
#import fairseq.models  # noqa
#import fairseq.modules  # noqa
#import fairseq.optim  # noqa
#import fairseq.optim.lr_scheduler  # noqa
#import fairseq.pdb  # noqa
#import fairseq.scoring  # noqa
#import fairseq.tasks  # noqa
#import fairseq.token_generation_constraints  # noqa

#import fairseq.benchmark  # noqa
#import fairseq.model_parallel  # noqa
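The sys.modules assignments above are what keep legacy fairseq.X import paths working after the internal reorganization. For orientation, a minimal, self-contained sketch of the same aliasing pattern (hypothetical module names, not part of this commit):

# Hypothetical sketch of the sys.modules aliasing pattern used above;
# "mypkg" and its submodules are stand-ins, not real fairseq modules.
import importlib
import sys
import types

new_home = types.ModuleType("mypkg.new_home")
new_home.answer = 42

# Register the same module object under the old dotted path, so legacy
# call sites that import "mypkg.old_home" resolve without a real file.
sys.modules["mypkg.old_home"] = new_home

old = importlib.import_module("mypkg.old_home")  # served from sys.modules
assert old.answer == 42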
381 modules/voice_conversion/fairseq/binarizer.py Normal file
@@ -0,0 +1,381 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Pool

import torch

from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line

logger = logging.getLogger("binarizer")


@dataclass
class BinarizeSummary:
    """
    Keep track of what's going on in the binarizer
    """

    num_seq: int = 0
    replaced: tp.Optional[Counter] = None
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base

        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary"):
        replaced = None
        if self.replaced is not None:
            replaced = self.replaced
        if other.replaced is not None:
            if replaced is None:
                replaced = other.replaced
            else:
                replaced += other.replaced
        self.replaced = replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok


class Binarizer(ABC):
    """
    a binarizer describes how to take a string and build a tensor out of it
    """

    @abstractmethod
    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ) -> torch.IntTensor:
        ...


def _worker_prefix(output_prefix: str, worker_id: int):
    return f"{output_prefix}.pt{worker_id}"


class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
        # we zip the list with itself shifted by one to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(
                            output_prefix,
                            worker_id,
                        ),
                        dataset_impl,
                    ),
                    kwds={
                        "vocab_size": vocab_size,
                    }
                    if vocab_size is not None
                    else {},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]

            pool.close()
            pool.join()
            for r in worker_results:
                summ = r.get()
                final_summary.merge(summ)

        # do not close the bin file as we need to merge the worker results in
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size if vocab_size is not None else None,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary

    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        Creates a dataset builder and appends binarized items to it. This function does not
        finalize the builder; this is useful if you want to do other things with your bin file,
        like appending/merging other files.
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary

    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        same as above, but also finalizes the builder
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ


class VocabularyDatasetBinarizer(Binarizer):
    """
    Takes a Dictionary/Vocabulary, assign ids to each
    token using the dictionary encode_line function.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        self.dict = dict
        self.tokenize = tokenize
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.already_numberized = already_numberized
        super().__init__()

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        if summary.replaced is None:
            summary.replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if self.already_numberized:
            id_strings = line.strip().split()
            id_list = [int(id_string) for id_string in id_strings]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)
        else:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids


class AlignmentDatasetBinarizer(Binarizer):
    """
    binarize by parsing a set of alignments and packing
    them in a tensor (see utils.parse_alignment)
    """

    def __init__(
        self,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        ids = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids


class LegacyBinarizer:
    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, int]:
        binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, int]:
        binarizer = AlignmentDatasetBinarizer(alignment_parser)
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }
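For orientation, a hedged usage sketch of the binarizer API added above (not part of the commit; the dictionary and corpus file names are hypothetical):

# Usage sketch; "dict.txt" and "corpus.txt" are hypothetical local files.
from fairseq.data import Dictionary
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer

vocab = Dictionary.load("dict.txt")            # hypothetical fairseq dictionary
line_binarizer = VocabularyDatasetBinarizer(vocab, append_eos=True)

# Tokenize and binarize corpus.txt into a memory-mapped .bin/.idx pair,
# splitting the work across four processes.
stats = FileBinarizer.multiprocess_dataset(
    input_file="corpus.txt",
    dataset_impl="mmap",
    binarizer=line_binarizer,
    output_prefix="corpus-bin",
    vocab_size=len(vocab),
    num_workers=4,
)
print(stats)  # e.g. "1000 sents, 27000 tokens"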
905 modules/voice_conversion/fairseq/checkpoint_utils.py Normal file
@@ -0,0 +1,905 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)


def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr


def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        if oc_version < "2.2":
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True

            state["cfg"] = OmegaConf.create(state["cfg"])

            _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state


def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args


def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename


def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task


def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        )

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    _arg_overrides = arg_overrides or {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)
        else:
            # fallback to non-atomic save
            with PathManager.open(filename, "wb") as f:
                _torch_persistent_save(obj, f)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state


def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict


def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types is not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=strict)
    return component


def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    else:
        os.remove(temp_file_path)


def save_ema_as_checkpoint(src_path, dst_path):
    state = load_ema_from_checkpoint(src_path)
    torch_persistent_save(state, dst_path)


def load_ema_from_checkpoint(fpath):
    """Loads exponential moving averaged (EMA) checkpoint from input and
    returns a model with ema weights.

    Args:
        fpath: A string path of checkpoint to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    new_state = None

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

    # EMA model is stored in a separate "extra state"
    model_params = new_state["extra_state"]["ema"]

    for key in list(model_params.keys()):
        p = model_params[key]
        if isinstance(p, torch.HalfTensor):
            p = p.float()
        if key not in params_dict:
            params_dict[key] = p.clone()
            # NOTE: clone() is needed in case of p is a shared parameter
        else:
            raise ValueError("Key {} is repeated in EMA model params.".format(key))

    if len(params_dict) == 0:
        raise ValueError(
            f"Input checkpoint path '{fpath}' does not contain "
            "ema model weights, is this model trained with EMA?"
        )

    new_state["model"] = params_dict
    return new_state
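For orientation, a hedged sketch of loading a model through the utilities above, which is how downstream speech pipelines such as RVC typically pull in a fairseq checkpoint like a HuBERT feature extractor (the checkpoint path below is hypothetical, not part of the commit):

# Usage sketch; the checkpoint path is hypothetical.
import torch
from fairseq import checkpoint_utils

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["data/models/hubert_base.pt"],   # hypothetical checkpoint location
)
model = models[0].eval()              # single-model "ensemble"; inference mode
if torch.cuda.is_available():
    model = model.cuda()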
130 modules/voice_conversion/fairseq/data/__init__.py Normal file
@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .dictionary import Dictionary, TruncatedDictionary

from .fairseq_dataset import FairseqDataset, FairseqIterableDataset

from .base_wrapper_dataset import BaseWrapperDataset

from .add_target_dataset import AddTargetDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset
from .audio.hubert_dataset import HubertDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import (
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    MMapIndexedDataset,
)
from .language_pair_dataset import LanguagePairDataset
from .list_dataset import ListDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .strip_token_dataset import StripTokenDataset
from .subsample_dataset import SubsampleDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .shorten_dataset import TruncateDataset, RandomCropDataset
from .multilingual.sampled_multi_dataset import SampledMultiDataset
from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset
from .fasta_dataset import FastaDataset, EncodedFastaDataset
from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset

from .iterators import (
    CountingIterator,
    EpochBatchIterator,
    GroupedIterator,
    ShardedIterator,
)

__all__ = [
    "AddTargetDataset",
    "AppendTokenDataset",
    "BacktranslationDataset",
    "BaseWrapperDataset",
    "BinarizedAudioDataset",
    "BucketPadLengthDataset",
    "ColorizeDataset",
    "ConcatDataset",
    "ConcatSentencesDataset",
    "CountingIterator",
    "DenoisingDataset",
    "Dictionary",
    "EncodedFastaDataset",
    "EpochBatchIterator",
    "FairseqDataset",
    "FairseqIterableDataset",
    "FastaDataset",
    "FileAudioDataset",
    "GroupedIterator",
    "HubertDataset",
    "IdDataset",
    "IndexedCachedDataset",
    "IndexedDataset",
    "IndexedRawTextDataset",
    "LanguagePairDataset",
    "LeftPadDataset",
    "ListDataset",
    "LMContextWindowDataset",
    "LRUCacheDataset",
    "MaskTokensDataset",
    "MMapIndexedDataset",
    "MonolingualDataset",
    "MultiCorpusSampledDataset",
    "NestedDictionaryDataset",
    "NoisingDataset",
    "NumelDataset",
    "NumSamplesDataset",
    "OffsetTokensDataset",
    "PadDataset",
    "PrependDataset",
    "PrependTokenDataset",
    "RandomCropDataset",
    "RawLabelDataset",
    "ResamplingDataset",
    "ReplaceDataset",
    "RightPadDataset",
    "RollDataset",
    "RoundRobinZipDatasets",
    "SampledMultiDataset",
    "SampledMultiEpochDataset",
    "ShardedIterator",
    "SortDataset",
    "StripTokenDataset",
    "SubsampleDataset",
    "TokenBlockDataset",
    "TransformEosDataset",
    "TransformEosLangPairDataset",
    "TransformEosConcatLangPairDataset",
    "TruncateDataset",
    "TruncatedDictionary",
]
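A quick, hedged illustration of the package facade above: downstream code imports dataset and vocabulary classes straight from fairseq.data (not part of the commit):

# Usage sketch of the fairseq.data facade.
from fairseq.data import Dictionary

d = Dictionary()                    # specials: <s>=0, <pad>=1, </s>=2, <unk>=3
idx = d.add_symbol("hello")
print(d[idx], d.pad(), d.eos())     # -> hello 1 2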
83 modules/voice_conversion/fairseq/data/add_target_dataset.py Normal file
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset, data_utils
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


class AddTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        process_label=None,
        label_len_fn=None,
        add_to_input=False,
        text_compression_level=TextCompressionLevel.none,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.process_label = process_label
        self.label_len_fn = label_len_fn
        self.add_to_input = add_to_input
        self.text_compressor = TextCompressor(level=text_compression_level)

    def get_label(self, index, process_fn=None):
        lbl = self.labels[index]
        lbl = self.text_compressor.decompress(lbl)
        return lbl if process_fn is None else process_fn(lbl)

    def __getitem__(self, index):
        item = self.dataset[index]
        item["label"] = self.get_label(index, process_fn=self.process_label)
        return item

    def size(self, index):
        sz = self.dataset.size(index)
        own_sz = self.label_len_fn(self.get_label(index))
        return sz, own_sz

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]

        if self.add_to_input:
            eos = torch.LongTensor([self.eos])
            prev_output_tokens = [torch.cat([eos, t], dim=-1) for t in target]
            target = [torch.cat([t, eos], dim=-1) for t in target]
            collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            # net_input is a dict, so check membership with .get() (getattr on a
            # dict would never find the key and left this branch dead)
            if collated["net_input"].get("prev_output_tokens", None):
                collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens(
                    collated["net_input"]["prev_output_tokens"],
                    pad_idx=self.pad,
                    left_pad=False,
                )
        else:
            collated["ntokens"] = sum([len(t) for t in target])

        collated["target"] = target
        return collated

    def filter_indices_by_size(self, indices, max_sizes):
        indices, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return indices, ignored
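# --- Illustrative sketch (not part of the upstream file): how AddTargetDataset
# arranges labels when add_to_input=True. The eos index and label values below
# are toy stand-ins for real dictionary indices.
import torch

eos = torch.LongTensor([2])
labels = [torch.LongTensor([5, 6]), torch.LongTensor([7])]
prev_output_tokens = [torch.cat([eos, t]) for t in labels]  # decoder input (eos first)
targets = [torch.cat([t, eos]) for t in labels]             # training target (eos last)
assert prev_output_tokens[0].tolist() == [2, 5, 6]
assert targets[0].tolist() == [5, 6, 2]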
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item, item.new([self.token])])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
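# --- Illustrative sketch (not part of the upstream file): the append step in
# AppendTokenDataset.__getitem__, using tensor.new() so the appended token
# inherits the item's dtype/device. The token value is a hypothetical EOS index.
import torch

item = torch.LongTensor([10, 11, 12])
token = 2
item = torch.cat([item, item.new([token])])
assert item.tolist() == [10, 11, 12, 2]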
93
modules/voice_conversion/fairseq/data/audio/__init__.py
Normal file
@@ -0,0 +1,93 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
import importlib
import os
import numpy as np


class AudioTransform(ABC):
    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        pass


class CompositeAudioTransform(AudioTransform):
    def _from_config_dict(
        cls,
        transform_type,
        get_audio_transform,
        composite_cls,
        config=None,
        return_empty=False,
    ):
        _config = {} if config is None else config
        _transforms = _config.get(f"{transform_type}_transforms")

        if _transforms is None:
            if return_empty:
                _transforms = []
            else:
                return None

        transforms = [
            get_audio_transform(_t).from_config_dict(_config.get(_t))
            for _t in _transforms
        ]
        return composite_cls(transforms)

    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = (
            [self.__class__.__name__ + "("]
            + [f"    {t.__repr__()}" for t in self.transforms]
            + [")"]
        )
        return "\n".join(format_string)


def register_audio_transform(name, cls_type, registry, class_names):
    def register_audio_transform_cls(cls):
        if name in registry:
            raise ValueError(f"Cannot register duplicate transform ({name})")
        if not issubclass(cls, cls_type):
            raise ValueError(
                f"Transform ({name}: {cls.__name__}) must extend "
                f"{cls_type.__name__}"
            )
        if cls.__name__ in class_names:
            raise ValueError(
                f"Cannot register audio transform with duplicate "
                f"class name ({cls.__name__})"
            )
        registry[name] = cls
        class_names.add(cls.__name__)
        return cls

    return register_audio_transform_cls


def import_transforms(transforms_dir, transform_type):
    for file in os.listdir(transforms_dir):
        path = os.path.join(transforms_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(
                f"fairseq.data.audio.{transform_type}_transforms." + name
            )


# Utility fn for uniform numbers in transforms
def rand_uniform(a, b):
    return np.random.uniform() * (b - a) + a
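# --- Illustrative sketch (not part of the upstream file): the registry pattern
# register_audio_transform() implements, reduced to a standalone toy. The names
# `_REGISTRY`, `_register`, and `MyTransform` are hypothetical.
_REGISTRY, _CLASS_NAMES = {}, set()


def _register(name):
    def wrapper(cls):
        if name in _REGISTRY:
            raise ValueError(f"Cannot register duplicate transform ({name})")
        _REGISTRY[name] = cls
        _CLASS_NAMES.add(cls.__name__)
        return cls

    return wrapper


@_register("my_transform")
class MyTransform:
    pass


assert _REGISTRY["my_transform"] is MyTransform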
389
modules/voice_conversion/fairseq/data/audio/audio_utils.py
Normal file
@@ -0,0 +1,389 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import mmap
from pathlib import Path
import io
from typing import BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}


def convert_waveform(
    waveform: Union[np.ndarray, torch.Tensor],
    sample_rate: int,
    normalize_volume: bool = False,
    to_mono: bool = False,
    to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
    """convert a waveform:
    - to a target sample rate
    - from multi-channel to mono channel
    - volume normalization

    Args:
        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
            (channels x length)
        sample_rate (int): original sample rate
        normalize_volume (bool): perform volume normalization
        to_mono (bool): convert to mono channel if having multiple channels
        to_sample_rate (Optional[int]): target sample rate
    Returns:
        waveform (numpy.ndarray): converted 2D waveform (channels x length)
        sample_rate (float): target sample rate
    """
    try:
        import torchaudio.sox_effects as ta_sox
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    effects = []
    if normalize_volume:
        effects.append(["gain", "-n"])
    if to_sample_rate is not None and to_sample_rate != sample_rate:
        effects.append(["rate", f"{to_sample_rate}"])
    if to_mono and waveform.shape[0] > 1:
        effects.append(["channels", "1"])
    if len(effects) > 0:
        is_np_input = isinstance(waveform, np.ndarray)
        _waveform = torch.from_numpy(waveform) if is_np_input else waveform
        converted, converted_sample_rate = ta_sox.apply_effects_tensor(
            _waveform, sample_rate, effects
        )
        if is_np_input:
            converted = converted.numpy()
        return converted, converted_sample_rate
    return waveform, sample_rate


def get_waveform(
    path_or_fp: Union[str, BinaryIO],
    normalization: bool = True,
    mono: bool = True,
    frames: int = -1,
    start: int = 0,
    always_2d: bool = True,
    output_sample_rate: Optional[int] = None,
    normalize_volume: bool = False,
    waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): normalize values to [-1, 1] (Default: True)
        mono (bool): convert multi-channel audio to mono-channel one
        frames (int): the number of frames to read. (-1 for reading all)
        start (int): Where to start reading. A negative value counts from the end.
        always_2d (bool): always return 2D array even for mono-channel audios
        output_sample_rate (Optional[int]): output sample rate
        normalize_volume (bool): normalize volume
    Returns:
        waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
        sample_rate (float): sample rate
    """
    if isinstance(path_or_fp, str):
        ext = Path(path_or_fp).suffix
        if ext not in SF_AUDIO_FILE_EXTENSIONS:
            raise ValueError(f"Unsupported audio format: {ext}")

    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile: pip install soundfile")

    waveform, sample_rate = sf.read(
        path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
    )
    waveform = waveform.T  # T x C -> C x T
    waveform, sample_rate = convert_waveform(
        waveform,
        sample_rate,
        normalize_volume=normalize_volume,
        to_mono=mono,
        to_sample_rate=output_sample_rate,
    )

    if not normalization:
        waveform *= 2**15  # denormalized to 16-bit signed integers

    if waveform_transforms is not None:
        waveform, sample_rate = waveform_transforms(waveform, sample_rate)

    if not always_2d:
        waveform = waveform.squeeze(axis=0)

    return waveform, sample_rate


def get_features_from_npy_or_audio(path, waveform_transforms=None):
    ext = Path(path).suffix
    if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        raise ValueError(f'Unsupported file format for "{path}"')
    return (
        np.load(path)
        if ext == ".npy"
        else get_fbank(path, waveform_transforms=waveform_transforms)
    )


def get_features_or_waveform_from_stored_zip(
    path,
    byte_offset,
    byte_size,
    need_waveform=False,
    use_sample_rate=None,
    waveform_transforms=None,
):
    assert path.endswith(".zip")
    data = read_from_stored_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)
    if is_npy_data(data):
        features_or_waveform = np.load(f)
    elif is_sf_audio_data(data):
        features_or_waveform = (
            get_waveform(
                f,
                always_2d=False,
                output_sample_rate=use_sample_rate,
                waveform_transforms=waveform_transforms,
            )[0]
            if need_waveform
            else get_fbank(f, waveform_transforms=waveform_transforms)
        )
    else:
        raise ValueError(f'Unknown file format for "{path}"')
    return features_or_waveform


def get_features_or_waveform(
    path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
):
    """Get speech features from .npy file or waveform from .wav/.flac file.
    The file may be inside an uncompressed ZIP file and is accessed via byte
    offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
            "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.
        use_sample_rate (int): change sample rate for the input wave file

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.
    """
    _path, slice_ptr = parse_path(path)
    if len(slice_ptr) == 0:
        if need_waveform:
            return get_waveform(
                _path,
                always_2d=False,
                output_sample_rate=use_sample_rate,
                waveform_transforms=waveform_transforms,
            )[0]
        return get_features_from_npy_or_audio(
            _path, waveform_transforms=waveform_transforms
        )
    elif len(slice_ptr) == 2:
        features_or_waveform = get_features_or_waveform_from_stored_zip(
            _path,
            slice_ptr[0],
            slice_ptr[1],
            need_waveform=need_waveform,
            use_sample_rate=use_sample_rate,
            waveform_transforms=waveform_transforms,
        )
    else:
        raise ValueError(f"Invalid path: {path}")

    return features_or_waveform


def _get_kaldi_fbank(
    waveform: np.ndarray, sample_rate: int, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via PyKaldi."""
    try:
        from kaldi.feat.fbank import Fbank, FbankOptions
        from kaldi.feat.mel import MelBanksOptions
        from kaldi.feat.window import FrameExtractionOptions
        from kaldi.matrix import Vector

        mel_opts = MelBanksOptions()
        mel_opts.num_bins = n_bins
        frame_opts = FrameExtractionOptions()
        frame_opts.samp_freq = sample_rate
        opts = FbankOptions()
        opts.mel_opts = mel_opts
        opts.frame_opts = frame_opts
        fbank = Fbank(opts=opts)
        features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
        return features
    except ImportError:
        return None


def _get_torchaudio_fbank(
    waveform: np.ndarray, sample_rate, n_bins=80
) -> Optional[np.ndarray]:
    """Get mel-filter bank features via TorchAudio."""
    try:
        import torchaudio.compliance.kaldi as ta_kaldi

        waveform = torch.from_numpy(waveform)
        features = ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
        )
        return features.numpy()
    except ImportError:
        return None


def get_fbank(
    path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster CPP implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized."""
    waveform, sample_rate = get_waveform(
        path_or_fp, normalization=False, waveform_transforms=waveform_transforms
    )

    features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
    if features is None:
        features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable "
            "online filterbank feature extraction"
        )

    return features


def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78


def is_sf_audio_data(data: bytes) -> bool:
    is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
    is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
    is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
    return is_wav or is_flac or is_ogg


def mmap_read(path: str, offset: int, length: int) -> bytes:
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
            data = mmap_o[offset : offset + length]
    return data


def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
    return mmap_read(zip_path, offset, length)


def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse data path which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

    Args:
        path (str): the data path to parse

    Returns:
        file_path (str): the file path
        slice_ptr (list of int): empty in case 1;
          byte offset and length for the slice in case 2
    """

    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        _path, *slice_ptr = path.split(":")
        if not Path(_path).is_file():
            raise FileNotFoundError(f"File not found: {_path}")
    assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
    slice_ptr = [int(i) for i in slice_ptr]
    return _path, slice_ptr


def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
    padding = n_fft - win_length
    assert padding >= 0
    return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))


def get_fourier_basis(n_fft: int) -> torch.Tensor:
    basis = np.fft.fft(np.eye(n_fft))
    basis = np.vstack(
        [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
    )
    return torch.from_numpy(basis).float()


def get_mel_filters(
    sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
    try:
        import librosa
    except ImportError:
        raise ImportError("Please install librosa: pip install librosa")
    basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
    return torch.from_numpy(basis).float()


class TTSSpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        window_fn: callable = torch.hann_window,
        return_phase: bool = False,
    ) -> None:
        super(TTSSpectrogram, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.return_phase = return_phase

        basis = get_fourier_basis(n_fft).unsqueeze(1)
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer("basis", basis)

    def forward(
        self, waveform: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        padding = (self.n_fft // 2, self.n_fft // 2)
        x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
        x = F.conv1d(x, self.basis, stride=self.hop_length)
        real_part = x[:, : self.n_fft // 2 + 1, :]
        imag_part = x[:, self.n_fft // 2 + 1 :, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if self.return_phase:
            phase = torch.atan2(imag_part, real_part)
            return magnitude, phase
        return magnitude


class TTSMelScale(torch.nn.Module):
    def __init__(
        self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
    ) -> None:
        super(TTSMelScale, self).__init__()
        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
        self.register_buffer("basis", basis)

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:
        return torch.matmul(self.basis, specgram)
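# --- Illustrative sketch (not part of the upstream file): the two path forms
# parse_path() accepts. The split logic is shown standalone; the real function
# also verifies the file exists on disk. Paths here are hypothetical.
path = "feats.zip:1024:512"
_path, *slice_ptr = path.split(":")
assert (_path, [int(i) for i in slice_ptr]) == ("feats.zip", [1024, 512])
# get_features_or_waveform() then hands (1024, 512) to read_from_stored_zip(),
# which mmap-reads exactly those bytes out of the uncompressed ZIP archive.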
387
modules/voice_conversion/fairseq/data/audio/data_cfg.py
Normal file
@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        raise ImportError("Please install PyYAML: pip install PyYAML")
    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        from copy import deepcopy

        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
        the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name includes 'target' and whose decoder_type
        is not CTC is used.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx


class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})
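# --- Illustrative sketch (not part of the upstream file): a config dict of the
# shape S2TDataConfig reads from YAML. Keys mirror the properties above; the
# values are hypothetical examples, not project defaults.
example_config = {
    "vocab_filename": "dict.txt",
    "input_feat_per_channel": 80,
    "input_channels": 1,
    "sample_rate": 16_000,
    "feature_transforms": {"_train": ["utterance_cmvn", "specaugment"]},
}
# Relative paths would be resolved against the YAML's directory via
# _auto_convert_to_abs_path(); missing keys fall back to each property's default.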
@@ -0,0 +1,53 @@
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioDatasetTransform(AudioTransform):
    pass


AUDIO_DATASET_TRANSFORM_REGISTRY = {}
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()


def get_audio_dataset_transform(name):
    return AUDIO_DATASET_TRANSFORM_REGISTRY[name]


def register_audio_dataset_transform(name):
    return register_audio_transform(
        name,
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "dataset")


class CompositeAudioDatasetTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "dataset",
            get_audio_dataset_transform,
            CompositeAudioDatasetTransform,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        for t in self.transforms:
            if isinstance(t, cls):
                return t
        return None

    def has_transform(self, cls):
        return self.get_transform(cls) is not None
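# --- Illustrative sketch (not part of the upstream file): how a composite is
# assembled from config. The "dataset_transforms" key lists transform names;
# each name's sub-dict configures that transform. Values are hypothetical.
cfg = {
    "dataset_transforms": ["concataugment"],
    "concataugment": {"rate": 0.5, "max_tokens": 2000},
}
# CompositeAudioDatasetTransform.from_config_dict(cfg) looks each name up in
# AUDIO_DATASET_TRANSFORM_REGISTRY, builds it via its own from_config_dict(),
# and wraps the instances so get_transform()/has_transform() can query them.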
@@ -0,0 +1,61 @@
from typing import List
import numpy as np

from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    register_audio_dataset_transform,
)

_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}


@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return ConcatAugment(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("max_tokens", _DEFAULTS["max_tokens"]),
            _config.get("attempts", _DEFAULTS["attempts"]),
        )

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        max_tokens=_DEFAULTS["max_tokens"],
        attempts=_DEFAULTS["attempts"],
    ):
        self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"rate={self.rate}",
                    f"max_tokens={self.max_tokens}",
                    f"attempts={self.attempts}",
                ]
            )
            + ")"
        )

    def find_indices(self, index: int, n_frames: List[int], n_samples: int):
        # skip conditions: application rate, max_tokens limit exceeded
        if np.random.random() > self.rate:
            return [index]
        if self.max_tokens and n_frames[index] > self.max_tokens:
            return [index]

        # pick second sample to concatenate
        for _ in range(self.attempts):
            index2 = np.random.randint(0, n_samples)
            if index2 != index and (
                not self.max_tokens
                or n_frames[index] + n_frames[index2] < self.max_tokens
            ):
                return [index, index2]

        return [index]
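# --- Illustrative sketch (not part of the upstream file): the rejection
# sampling in find_indices(), standalone. The n_frames values are hypothetical.
import numpy as np

n_frames, max_tokens, attempts = [100, 200, 2900, 150], 3000, 5
index, pair = 0, [0]
for _ in range(attempts):
    index2 = np.random.randint(0, len(n_frames))
    if index2 != index and n_frames[index] + n_frames[index2] < max_tokens:
        pair = [index, index2]  # concatenate these two samples
        break
# Falls back to [index] (no augmentation) if no partner fits the token budget.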
@@ -0,0 +1,105 @@
import numpy as np
import torch

from fairseq.data.audio import rand_uniform
from fairseq.data.audio.dataset_transforms import (
    AudioDatasetTransform,
    register_audio_dataset_transform,
)
from fairseq.data.audio.waveform_transforms.noiseaugment import (
    NoiseAugmentTransform,
)

_DEFAULTS = {
    "rate": 0.25,
    "mixing_noise_rate": 0.1,
    "noise_path": "",
    "noise_snr_min": -5,
    "noise_snr_max": 5,
    "utterance_snr_min": -5,
    "utterance_snr_max": 5,
}


@register_audio_dataset_transform("noisyoverlapaugment")
class NoisyOverlapAugment(AudioDatasetTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return NoisyOverlapAugment(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
            _config.get("noise_path", _DEFAULTS["noise_path"]),
            _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
            _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
            _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
            _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
        )

    def __init__(
        self,
        rate=_DEFAULTS["rate"],
        mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
        noise_path=_DEFAULTS["noise_path"],
        noise_snr_min=_DEFAULTS["noise_snr_min"],
        noise_snr_max=_DEFAULTS["noise_snr_max"],
        utterance_snr_min=_DEFAULTS["utterance_snr_min"],
        utterance_snr_max=_DEFAULTS["utterance_snr_max"],
    ):
        self.rate = rate
        self.mixing_noise_rate = mixing_noise_rate
        self.noise_shaper = NoiseAugmentTransform(noise_path)
        self.noise_snr_min = noise_snr_min
        self.noise_snr_max = noise_snr_max
        self.utterance_snr_min = utterance_snr_min
        self.utterance_snr_max = utterance_snr_max

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"rate={self.rate}",
                    f"mixing_noise_rate={self.mixing_noise_rate}",
                    f"noise_snr_min={self.noise_snr_min}",
                    f"noise_snr_max={self.noise_snr_max}",
                    f"utterance_snr_min={self.utterance_snr_min}",
                    f"utterance_snr_max={self.utterance_snr_max}",
                ]
            )
            + ")"
        )

    def __call__(self, sources):
        for i, source in enumerate(sources):
            if np.random.random() > self.rate:
                continue

            pri = source.numpy()

            if np.random.random() > self.mixing_noise_rate:
                sec = sources[np.random.randint(0, len(sources))].numpy()
                snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
            else:
                sec = self.noise_shaper.pick_sample(source.shape)
                snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)

            L1 = pri.shape[-1]
            L2 = sec.shape[-1]
            l = np.random.randint(0, min(round(L1 / 2), L2))  # mix len
            s_source = np.random.randint(0, L1 - l)
            s_sec = np.random.randint(0, L2 - l)

            get_power = lambda x: np.mean(x**2)
            if get_power(sec) == 0:
                continue

            scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))

            pri[s_source : s_source + l] = np.add(
                pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
            )
            sources[i] = torch.from_numpy(pri).float()

        return sources
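# --- Illustrative sketch (not part of the upstream file): the SNR scaling used
# above. Scaling the secondary by sqrt(P_pri / (10^(snr/10) * P_sec)) makes
# 10*log10(P_pri / P_sec_scaled) equal the sampled snr (in dB).
import numpy as np

pri = np.random.randn(16000)
sec = np.random.randn(16000)
snr = 5.0
get_power = lambda x: np.mean(x**2)
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
achieved = 10 * np.log10(get_power(pri) / get_power(scl * sec))
assert abs(achieved - snr) < 1e-6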
@@ -0,0 +1,43 @@
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioFeatureTransform(AudioTransform):
    pass


AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()


def get_audio_feature_transform(name):
    return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]


def register_audio_feature_transform(name):
    return register_audio_transform(
        name,
        AudioFeatureTransform,
        AUDIO_FEATURE_TRANSFORM_REGISTRY,
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "feature")


class CompositeAudioFeatureTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "feature",
            get_audio_feature_transform,
            CompositeAudioFeatureTransform,
            config,
        )
@@ -0,0 +1,37 @@
import numpy as np
import torch
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("delta_deltas")
class DeltaDeltas(AudioFeatureTransform):
    """Expand delta-deltas features from spectrum."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return DeltaDeltas(_config.get("win_length", 5))

    def __init__(self, win_length=5):
        self.win_length = win_length

    def __repr__(self):
        return self.__class__.__name__

    def __call__(self, spectrogram):
        from torchaudio.functional import compute_deltas

        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
        # spectrogram is T x F, while compute_deltas takes (..., F, T)
        spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
        delta = compute_deltas(spectrogram)
        delta_delta = compute_deltas(delta)

        out_feat = np.concatenate(
            [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
        )
        out_feat = np.transpose(out_feat)
        return out_feat
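# --- Illustrative sketch (not part of the upstream file): shape bookkeeping in
# DeltaDeltas, shown with plain NumPy (no torchaudio needed for the shapes).
import numpy as np

spec = np.zeros((100, 80), dtype=np.float32)                # T x F input
stacked = np.concatenate([spec.T, spec.T, spec.T], axis=0)  # base/delta/delta-delta
assert np.transpose(stacked).shape == (100, 240)            # T x 3F output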
@@ -0,0 +1,29 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("global_cmvn")
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return GlobalCMVN(_config.get("stats_npz_path"))

    def __init__(self, stats_npz_path):
        self.stats_npz_path = stats_npz_path
        stats = np.load(stats_npz_path)
        self.mean, self.std = stats["mean"], stats["std"]

    def __repr__(self):
        return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")'

    def __call__(self, x):
        x = np.subtract(x, self.mean)
        x = np.divide(x, self.std)
        return x
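# --- Illustrative sketch (not part of the upstream file): producing the .npz
# stats file GlobalCMVN expects and applying the normalization. The path and
# the toy features are hypothetical.
import numpy as np

feats = np.random.randn(1000, 80).astype(np.float32)
np.savez("/tmp/gcmvn_stats.npz", mean=feats.mean(axis=0), std=feats.std(axis=0))
stats = np.load("/tmp/gcmvn_stats.npz")
normalized = (feats - stats["mean"]) / stats["std"]
assert abs(float(normalized.mean())) < 1e-3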
@@ -0,0 +1,131 @@
import math
import numbers
from typing import Optional

import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("specaugment")
class SpecAugmentTransform(AudioFeatureTransform):
    """SpecAugment (https://arxiv.org/abs/1904.08779)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return SpecAugmentTransform(
            _config.get("time_warp_W", 0),
            _config.get("freq_mask_N", 0),
            _config.get("freq_mask_F", 0),
            _config.get("time_mask_N", 0),
            _config.get("time_mask_T", 0),
            _config.get("time_mask_p", 0.0),
            _config.get("mask_value", None),
        )

    def __init__(
        self,
        time_warp_w: int = 0,
        freq_mask_n: int = 0,
        freq_mask_f: int = 0,
        time_mask_n: int = 0,
        time_mask_t: int = 0,
        time_mask_p: float = 0.0,
        mask_value: Optional[float] = 0.0,
    ):
        # Sanity checks
        assert mask_value is None or isinstance(
            mask_value, numbers.Number
        ), f"mask_value (type: {type(mask_value)}) must be None or a number"
        if freq_mask_n > 0:
            assert freq_mask_f > 0, (
                f"freq_mask_F ({freq_mask_f}) "
                f"must be larger than 0 when doing freq masking."
            )
        if time_mask_n > 0:
            assert time_mask_t > 0, (
                f"time_mask_T ({time_mask_t}) must be larger than 0 when "
                f"doing time masking."
            )

        self.time_warp_w = time_warp_w
        self.freq_mask_n = freq_mask_n
        self.freq_mask_f = freq_mask_f
        self.time_mask_n = time_mask_n
        self.time_mask_t = time_mask_t
        self.time_mask_p = time_mask_p
        self.mask_value = mask_value

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"time_warp_w={self.time_warp_w}",
                    f"freq_mask_n={self.freq_mask_n}",
                    f"freq_mask_f={self.freq_mask_f}",
                    f"time_mask_n={self.time_mask_n}",
                    f"time_mask_t={self.time_mask_t}",
                    f"time_mask_p={self.time_mask_p}",
                ]
            )
            + ")"
        )

    def __call__(self, spectrogram):
        assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."

        distorted = spectrogram.copy()  # make a copy of input spectrogram.
        num_frames = spectrogram.shape[0]  # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'miu' in the paper.
        mask_value = self.mask_value

        if mask_value is None:  # if no value was specified, use local mean.
            mask_value = spectrogram.mean()

        if num_frames == 0:
            return spectrogram

        if num_freqs < self.freq_mask_f:
            return spectrogram

        if self.time_warp_w > 0:
            if 2 * self.time_warp_w < num_frames:
                import cv2

                w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
                w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
                upper, lower = distorted[:w0, :], distorted[w0:, :]
                upper = cv2.resize(
                    upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
                )
                lower = cv2.resize(
                    lower,
                    dsize=(num_freqs, num_frames - w0 - w),
                    interpolation=cv2.INTER_LINEAR,
                )
                distorted = np.concatenate((upper, lower), axis=0)

        for _i in range(self.freq_mask_n):
            f = np.random.randint(0, self.freq_mask_f)
            f0 = np.random.randint(0, num_freqs - f)
            if f != 0:
                distorted[:, f0 : f0 + f] = mask_value

        max_time_mask_t = min(
            self.time_mask_t, math.floor(num_frames * self.time_mask_p)
        )
        if max_time_mask_t < 1:
            return distorted

        for _i in range(self.time_mask_n):
            t = np.random.randint(0, max_time_mask_t)
            t0 = np.random.randint(0, num_frames - t)
            if t != 0:
                distorted[t0 : t0 + t, :] = mask_value

        return distorted
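# --- Illustrative sketch (not part of the upstream file): a SpecAugment config
# block. Values follow the LibriSpeech "LD" policy from the paper (with time
# warping disabled, since warping needs cv2); treat them as an example only.
specaugment_config = {
    "time_warp_W": 0,
    "freq_mask_N": 2,    # two frequency masks...
    "freq_mask_F": 27,   # ...each up to 27 bins wide
    "time_mask_N": 2,    # two time masks...
    "time_mask_T": 100,  # ...each up to 100 frames long
    "time_mask_p": 1.0,  # no extra cap relative to utterance length
}
# transform = SpecAugmentTransform.from_config_dict(specaugment_config)
# distorted = transform(spectrogram)  # spectrogram: (num_frames, num_freqs) ndarray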
@@ -0,0 +1,41 @@
import numpy as np

from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform,
    register_audio_feature_transform,
)


@register_audio_feature_transform("utterance_cmvn")
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return UtteranceCMVN(
            _config.get("norm_means", True),
            _config.get("norm_vars", True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means, self.norm_vars = norm_means, norm_vars

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
        )

    def __call__(self, x):
        mean = x.mean(axis=0)
        square_sums = (x**2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if self.norm_vars:
            var = square_sums / x.shape[0] - mean**2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        return x
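# --- Illustrative sketch (not part of the upstream file): UtteranceCMVN on a
# toy feature matrix; per-utterance mean and variance are removed per dimension
# (np.var here equals the E[x^2] - mean^2 form used above).
import numpy as np

x = np.random.randn(200, 80) * 3.0 + 1.5
y = (x - x.mean(axis=0)) / np.sqrt(np.maximum(x.var(axis=0), 1e-10))
assert np.allclose(y.mean(axis=0), 0.0, atol=1e-6)
assert np.allclose(y.std(axis=0), 1.0, atol=1e-4)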
@@ -0,0 +1,205 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import csv
import logging
import os.path as op
from typing import List, Optional

import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
    TextToSpeechDataset,
    TextToSpeechDatasetCreator,
)

logger = logging.getLogger(__name__)


class FrmTextToSpeechDataset(TextToSpeechDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        do_chunk=False,
        chunk_bound=-1,
        chunk_init=50,
        chunk_incr=5,
        add_eos=True,
        dedup=True,
        ref_fpu=-1,
    ):
        # It assumes texts are encoded at a fixed frame-rate
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        self.do_chunk = do_chunk
        self.chunk_bound = chunk_bound
        self.chunk_init = chunk_init
        self.chunk_incr = chunk_incr
        self.add_eos = add_eos
        self.dedup = dedup
        self.ref_fpu = ref_fpu

        self.chunk_size = -1

        if do_chunk:
            assert self.chunk_incr >= 0
            assert self.pre_tokenizer is None

    def __getitem__(self, index):
        index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
        if target[-1].item() == self.tgt_dict.eos_index:
            target = target[:-1]

        fpu = source.size(0) / target.size(0)  # frame-per-unit
        fps = self.n_frames_per_step
        assert (
            self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
        ), f"{fpu*fps} != {self.ref_fpu}"

        # only chunk training split
        if self.is_train_split and self.do_chunk and self.chunk_size > 0:
            lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
            text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
            size = len(text)
            chunk_size = min(self.chunk_size, size)
            chunk_start = np.random.randint(size - chunk_size + 1)
            text = text[chunk_start : chunk_start + chunk_size]
            target = torch.cat((lang, text), 0)

            f_size = int(np.floor(chunk_size * fpu))
            f_start = int(np.floor(chunk_start * fpu))
            assert f_size > 0
            source = source[f_start : f_start + f_size, :]

        if self.dedup:
            target = torch.unique_consecutive(target)

        if self.add_eos:
            eos_idx = self.tgt_dict.eos_index
            target = torch.cat((target, torch.LongTensor([eos_idx])), 0)

        return index, source, target, speaker_id

    def set_epoch(self, epoch):
        if self.is_train_split and self.do_chunk:
            old = self.chunk_size
            self.chunk_size = self.chunk_init + epoch * self.chunk_incr
            if self.chunk_bound > 0:
                self.chunk_size = min(self.chunk_size, self.chunk_bound)
            logger.info(
                (
                    f"{self.split}: setting chunk size "
                    f"from {old} to {self.chunk_size}"
                )
            )


class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
    # inherit for key names
    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        n_frames_per_step: int,
        speaker_to_id,
        do_chunk: bool = False,
        chunk_bound: int = -1,
        chunk_init: int = 50,
        chunk_incr: int = 5,
        add_eos: bool = True,
        dedup: bool = True,
        ref_fpu: float = -1,
    ) -> FrmTextToSpeechDataset:
        tsv_path = op.join(root, f"{split}.tsv")
        if not op.isfile(tsv_path):
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            s = [dict(e) for e in reader]
            assert len(s) > 0

        ids = [ss[cls.KEY_ID] for ss in s]
        audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
        n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
        tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
        src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
        speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
        src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
        tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]

        return FrmTextToSpeechDataset(
            split=split,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            do_chunk=do_chunk,
            chunk_bound=chunk_bound,
            chunk_init=chunk_init,
            chunk_incr=chunk_incr,
            add_eos=add_eos,
            dedup=dedup,
            ref_fpu=ref_fpu,
        )
356
modules/voice_conversion/fairseq/data/audio/hubert_dataset.py
Normal file
@@ -0,0 +1,356 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
import sys
from typing import Any, List, Optional, Union

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
)
import io

logger = logging.getLogger(__name__)


def load_audio(manifest_path, max_keep, min_keep):
    n_long, n_short = 0, 0
    names, inds, sizes = [], [], []
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            assert len(items) == 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                inds.append(ind)
                sizes.append(sz)
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes


def load_label(label_path, inds, tot):
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        assert (
            len(labels) == tot
        ), f"number of labels does not match ({len(labels)} != {tot})"
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot):
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        assert (
            len(code_lengths) == tot
        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets
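

# --- Illustrative aside (not part of the original fairseq sources) ---
# Self-contained sketch of the byte-offset trick used by `load_label_offset`
# above: cumulative UTF-8 lengths give a (start, end) span per line, so a
# label can later be fetched with seek/read instead of loading the whole file.
def _label_offset_example():
    lines = ["a b c\n", "d e\n", "f\n"]
    code_lengths = [len(line.encode("utf-8")) for line in lines]
    offsets = list(itertools.accumulate([0] + code_lengths))
    spans = [(offsets[i], offsets[i + 1]) for i in range(len(lines))]
    blob = "".join(lines)
    start, end = spans[1]
    assert blob[start:end] == "d e\n"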


def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
    ):
        self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, float)
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [load_label(p, inds, tot) for p in label_paths]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
        )

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 0:
            wav, cur_sample_rate = sf.read(_path)
        else:
            assert _path.endswith(".zip")
            data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            f = io.BytesIO(data)
            wav, cur_sample_rate = sf.read(f)
        wav = torch.from_numpy(wav).float()
        wav = self.postprocess(wav, cur_sample_rate)
        return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start
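
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # Standalone sketch of the random crop above: for a 10-sample wav cropped
    # to 6 samples, the start is drawn from [0, 4] and exactly target_size
    # samples are kept.
    @staticmethod
    def _crop_example():
        wav = torch.arange(10)
        target_size = 6
        diff = len(wav) - target_size
        start = int(np.random.randint(0, diff + 1))
        cropped = wav[start : start + target_size]
        assert len(cropped) == target_size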

    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size):
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens
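
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # Worked example of the sample-to-frame conversion above, with hypothetical
    # numbers: 16 kHz audio and 50 Hz labels give s2f = 50 / 16000 frames per
    # sample, so a crop starting at sample 32000 begins at label frame 100 and
    # 80000 samples span 250 label frames.
    @staticmethod
    def _frame_rate_conversion_example():
        sample_rate, label_rate = 16000.0, 50.0
        s2f = label_rate / sample_rate
        assert int(round(32000 * s2f)) == 100
        assert int(round(80000 * s2f)) == 250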

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1.0:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        order.append(self.sizes)
        return np.lexsort(order)[::-1]

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
284
modules/voice_conversion/fairseq/data/audio/multi_modality_dataset.py
Normal file
@@ -0,0 +1,284 @@
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import math
from typing import Any, List, Optional, NamedTuple

import numpy as np
from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
    ConcatDataset,
    LanguagePairDataset,
    FileAudioDataset,
    data_utils,
)
from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class ModalityDatasetItem(NamedTuple):
    datasetname: str
    dataset: Any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None


def resampling_dataset_present(ds):
    if isinstance(ds, ResamplingDataset):
        return True
    if isinstance(ds, ConcatDataset):
        return any(resampling_dataset_present(d) for d in ds.datasets)
    if hasattr(ds, "dataset"):
        return resampling_dataset_present(ds.dataset)
    return False


# MultiModalityDataset concatenates multiple datasets with different modalities.
# Compared with ConcatDataset, it can 1) sample data given ratios for the
# different datasets and 2) add a mode to indicate which type of dataset a
# sample comes from. It is used together with GroupedEpochBatchIterator to
# generate mini-batches with samples from the same type of dataset.
# If only one dataset is used, it behaves like the original dataset, with the
# mode added.
class MultiModalityDataset(ConcatDataset):
    def __init__(self, datasets: List[ModalityDatasetItem]):
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        weights = [1.0 for s in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)

    def collater(self, samples):
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # make sure all samples in samples are from same dataset
        assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
        # add mode
        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]

        return samples

    def size(self, index: int):
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Return indices sorted by length, so that less padding is needed.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            # If we have ResamplingDataset, the same id can correspond to a different
            # sample in the next epoch, so we need to rebuild this at every epoch
            if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
                ds
            ):
                logger.info(f"dataset {i} is valid and not re-sampled")
                continue
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i < len(self.raw_sub_batch_samplers):
                self.raw_sub_batch_samplers[i] = sub_batch_sampler
            else:
                self.raw_sub_batch_samplers.append(sub_batch_sampler)

    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                sub_batch_sampler = [
                    [y + self.cumulative_sizes[i - 1] for y in x]
                    for x in self.raw_sub_batch_samplers[i]
                ]
            else:
                sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
            smp_r = mult_ratios[i]
            if smp_r != 1:
                is_increase = "increased" if smp_r > 1 else "decreased"
                logger.info(
                    "number of batches for dataset {} is {} from {} to {}".format(
                        self.id_to_mode[i],
                        is_increase,
                        len(sub_batch_sampler),
                        int(len(sub_batch_sampler) * smp_r),
                    )
                )
                mul_samplers = []
                for _ in range(math.floor(smp_r)):
                    mul_samplers = mul_samplers + sub_batch_sampler
                if math.floor(smp_r) != smp_r:
                    with data_utils.numpy_seed(seed + self._cur_epoch):
                        np.random.shuffle(sub_batch_sampler)
                        smp_num = int(
                            (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
                        )
                        mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
                sub_batch_sampler = mul_samplers
            else:
                logger.info(
                    "dataset {} has {} batches".format(
                        self.id_to_mode[i], len(sub_batch_sampler)
                    )
                )
            batch_samplers.append(sub_batch_sampler)

        return batch_samplers
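

# --- Illustrative aside (not part of the original fairseq sources) ---
# Hypothetical wiring sketch of the pieces above. `speech_ds` and `text_ds`
# are assumed to be existing FairseqDataset instances, and the limits and
# ratios below are made-up values; in real usage, the returned samplers are
# fed to GroupedEpochBatchIterator.
def _multi_modality_usage_sketch(speech_ds, text_ds):
    items = [
        ModalityDatasetItem("sup_speech", speech_ds, [1024, 1024], 40000, 32),
        ModalityDatasetItem("text", text_ds, [1024, 1024], 4096, 64),
    ]
    mm_ds = MultiModalityDataset(items)
    return mm_ds.get_batch_samplers(
        mult_ratios=[1.0, 0.5], required_batch_size_multiple=8, seed=1
    )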


class LangPairMaskDataset(FairseqDataset):
    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # dataset.sizes can be a dynamically computed property:
        return self.dataset.sizes

    def get_batch_shapes(self):
        if hasattr(self.dataset, "get_batch_shapes"):
            return self.dataset.get_batch_shapes()
        return self.dataset.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        mask_src_item = src_item.masked_fill(mask, self.noise_id)
        smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
        return smp
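
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # Standalone sketch of the "tail" masking branch above: the last
    # mask_ratio fraction of tokens is replaced by noise_id, sparing BOS/EOS.
    # The token ids below are hypothetical.
    @staticmethod
    def _tail_mask_example():
        src_item = torch.LongTensor([0, 11, 12, 13, 14, 2])  # 0=bos, 2=eos
        mask_ratio, noise_id = 0.5, 99
        mask = torch.ones(len(src_item))
        mask[: int(len(src_item) * (1 - mask_ratio))] = 0
        mask = mask.eq(1)
        mask[0], mask[-1] = False, False  # never mask bos/eos
        masked = src_item.masked_fill(mask, noise_id)
        assert masked.tolist() == [0, 11, 12, 99, 99, 2]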

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            sample = self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)


class FileAudioDatasetWrapper(FileAudioDataset):
    def collater(self, samples):
        samples = super().collater(samples)
        if len(samples) == 0:
            return {}
        samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
        samples["net_input"]["prev_output_tokens"] = None
        del samples["net_input"]["source"]
        samples["net_input"]["src_lengths"] = None
        samples["net_input"]["alignment"] = None
        return samples
393
modules/voice_conversion/fairseq/data/audio/raw_audio_dataset.py
Normal file
@@ -0,0 +1,393 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import logging
import os
import sys
import io

import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]

    def _compute_mask_indices(self, dims, padding_mask):
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)

        if self.compute_mask_indices:
            B = input["source"].size(0)
            T = self._get_mask_indices_dims(input["source"].size(-1))
            padding_mask_reshaped = input["padding_mask"].clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            input["mask_indices"] = mask_indices
            input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]
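
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # The loop above applies the standard Conv1d length formula per layer:
    # L_out = (L_in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1.
    # With the layer spec commonly used by wav2vec 2.0's feature extractor
    # (an assumption here, not taken from this file), 16000 samples (1 s at
    # 16 kHz) reduce to 49 frames:
    @staticmethod
    def _conv_out_length_example():
        layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
        L = 16000
        for _, kernel_size, stride in layers:
            L = (L - (kernel_size - 1) - 1) // stride + 1
        assert L == 49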

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""

        if self.shuffle:
            order = [np.random.permutation(len(self))]
            order.append(
                np.minimum(
                    np.array(self.sizes),
                    self.max_sample_size,
                )
            )
            return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def set_bucket_info(self, num_buckets):
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )


class FileAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except Exception:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        import soundfile as sf

        fn = self.fnames[index]
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}


class BinarizedAudioDataset(RawAudioDataset):
    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
        lengths_path = os.path.join(data_dir, f"{split}.lengths")

        with open(lengths_path, "r") as f:
            for line in f:
                sz = int(line.rstrip())
                assert (
                    sz >= min_sample_size
                ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
                self.sizes.append(sz)

        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")

    def __getitem__(self, index):
        import soundfile as sf

        fname = self.fnames_dict.string(self.fnames[index], separator="")
        if self.root_dir:
            fname = os.path.join(self.root_dir, fname)

        wav, curr_sample_rate = sf.read(fname)
        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}
379
modules/voice_conversion/fairseq/data/audio/speech_to_speech_dataset.py
Normal file
@@ -0,0 +1,379 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch

from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
    _collate_frames,
)

logger = logging.getLogger(__name__)


@dataclass
class SpeechToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    target_speaker: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None


class SpeechToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2SDataConfig,
        src_audio_paths: List[str],
        src_n_frames: List[int],
        tgt_audio_paths: List[str],
        tgt_n_frames: List[int],
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
    ):
        tgt_texts = tgt_audio_paths if target_is_code else None
        super().__init__(
            split=split,
            is_train_split=is_train_split,
            cfg=data_cfg,
            audio_paths=src_audio_paths,
            n_frames=src_n_frames,
            ids=ids,
            tgt_dict=tgt_dict,
            tgt_texts=tgt_texts,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            n_frames_per_step=n_frames_per_step,
        )

        self.tgt_audio_paths = tgt_audio_paths
        self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]

        assert not target_is_code or tgt_dict is not None
        self.target_is_code = target_is_code

        assert len(tgt_audio_paths) == self.n_samples
        assert len(tgt_n_frames) == self.n_samples

        self.tgt_speakers = None
        if self.cfg.target_speaker_embed:
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
                self.cfg.target_speaker_embed, split
            )
            spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
            self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
            assert len(self.tgt_speakers) == self.n_samples

        logger.info(self.__repr__())

    def pack_units(self, input: torch.Tensor) -> torch.Tensor:
        if self.n_frames_per_step <= 1:
            return input

        offset = 4
        vocab_size = (
            len(self.tgt_dict) - offset
        )  # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary

        assert input.dim() == 1
        stacked_input = (
            input[:-1].view(-1, self.n_frames_per_step) - offset
        )  # remove <eos>
        scale = [
            pow(vocab_size, self.n_frames_per_step - 1 - i)
            for i in range(self.n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0)
        res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
        res[:-1] = (stacked_input * scale).sum(dim=1) + offset

        return res
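
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # Worked example of the packing above, with hypothetical numbers: a
    # dictionary of size 14 (so vocab_size = 10 after removing the 4 special
    # symbols), n_frames_per_step = 2, and eos = 2. The unit pairs [4, 5] and
    # [6, 7] become the single ids 0*10 + 1 + 4 = 5 and 2*10 + 3 + 4 = 27.
    @staticmethod
    def _pack_units_example():
        offset, vocab_size, r = 4, 10, 2
        input = torch.LongTensor([4, 5, 6, 7, 2])  # ...units..., eos=2
        stacked = input[:-1].view(-1, r) - offset
        scale = torch.LongTensor([vocab_size ** (r - 1 - i) for i in range(r)])
        res = input.new((len(input) - 1) // r + 1).fill_(input[-1])
        res[:-1] = (stacked * scale).sum(dim=1) + offset
        assert res.tolist() == [5, 27, 2]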

    def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
        source = self._get_source_audio(index)

        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_as_bos:
            # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)

        if not self.target_is_code:
            target = get_features_or_waveform(self.tgt_audio_paths[index])
            target = torch.from_numpy(target).float()
            target = self.pack_frames(target)
        else:
            target = self.tgt_dict.encode_line(
                self.tgt_audio_paths[index],
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.n_frames_per_step > 1:
                n_tgt_frame = target.size(0) - 1  # exclude <eos>
                keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
                target = torch.cat(
                    (
                        target[:keep_n_tgt_frame],
                        target.new_full((1,), self.tgt_dict.eos()),
                    ),
                    dim=0,
                )

        if self.tgt_speakers:
            tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
            tgt_spk = torch.from_numpy(tgt_spk).float()
        else:
            tgt_spk = torch.FloatTensor([])

        return SpeechToSpeechDatasetItem(
            index=index,
            source=source,
            target=target,
            target_speaker=tgt_spk,
            tgt_lang_tag=tgt_lang_tag,
        )

    def _collate_target(
        self, samples: List[SpeechToSpeechDatasetItem]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.target_is_code:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            # convert stacked units to a single id
            pack_targets = [self.pack_units(x.target) for x in samples]
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                pack_targets,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            target_lengths = torch.tensor(
                [x.size(0) for x in pack_targets], dtype=torch.long
            )
        else:
            target = _collate_frames([x.target for x in samples], is_audio_input=False)
            bsz, _, d = target.size()
            prev_output_tokens = torch.cat(
                (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
            )
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            )

        return target, prev_output_tokens, target_lengths

    def collater(
        self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
        frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, prev_output_tokens, target_lengths = self._collate_target(samples)
        target = target.index_select(0, order)
        target_lengths = target_lengths.index_select(0, order)
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(x.target.size(0) for x in samples)

        tgt_speakers = None
        if self.cfg.target_speaker_embed:
            tgt_speakers = _collate_frames(
                [x.target_speaker for x in samples], is_audio_input=True
            ).index_select(0, order)

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "tgt_speaker": tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
        }
        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out


class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
        s2s_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2s_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out


class SpeechToSpeechDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
    KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
    # optional columns
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        data_cfg: S2SDataConfig,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        audio_root = Path(data_cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        src_audio_paths = [
            (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
        ]
        tgt_audio_paths = [
            s[cls.KEY_TGT_AUDIO]
            if target_is_code
            else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
            for s in samples
        ]
        src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
        tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
        )

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            data_cfg=data_cfg,
            src_audio_paths=src_audio_paths,
            src_n_frames=src_n_frames,
            tgt_audio_paths=tgt_audio_paths,
            tgt_n_frames=tgt_n_frames,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            target_is_code=target_is_code,
            tgt_dict=tgt_dict,
            n_frames_per_step=n_frames_per_step,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2SDataConfig,
        splits: str,
        is_train_split: bool,
        epoch: int,
        seed: int,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        datasets = []
        for split in splits.split(","):
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
            ds = cls._from_list(
                split_name=split,
                is_train_split=is_train_split,
                samples=samples,
                data_cfg=data_cfg,
                target_is_code=target_is_code,
                tgt_dict=tgt_dict,
                n_frames_per_step=n_frames_per_step,
                multitask=multitask,
            )
            datasets.append(ds)
        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
733
modules/voice_conversion/fairseq/data/audio/speech_to_text_dataset.py
Normal file
@@ -0,0 +1,733 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
    NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

logger = logging.getLogger(__name__)


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
            length of the i-th frame and f_dim is the static feature dimension
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out
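

# --- Illustrative aside (not part of the original fairseq sources) ---
# Sketch of the padding behavior above: feature matrices of 3 and 1 frames
# (f_dim = 2) collate into a (2, 3, 2) tensor, zero-padded past each length.
def _collate_frames_example():
    a = torch.ones(3, 2)
    b = torch.ones(1, 2)
    out = _collate_frames([a, b])
    assert tuple(out.shape) == (2, 3, 2)
    assert out[1, 1:].eq(0).all()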


def _is_int_or_np_int(n):
    return isinstance(n, int) or (
        isinstance(n, np.generic) and isinstance(n.item(), int)
    )


@dataclass
class SpeechToTextDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None


class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        append_eos=True,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.cfg = cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.speakers = speakers
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        self.shuffle = cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.cfg.get_feature_transforms(split, is_train_split)
        )
        self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
            self.cfg.get_waveform_transforms(split, is_train_split)
        )
        # TODO: add these to data_cfg.py
        self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
            self.cfg.get_dataset_transforms(split, is_train_split)
        )

        # check proper usage of transforms
        if self.feature_transforms and self.cfg.use_audio_input:
            logger.warning(
                "Feature transforms will not be applied. To use feature transforms, "
                "set use_audio_input as False in config."
            )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.n_frames_per_step = n_frames_per_step
        self.speaker_to_id = speaker_to_id

        self.tgt_lens = self.get_tgt_lens_and_check_oov()
        self.append_eos = append_eos

        logger.info(self.__repr__())

    def get_tgt_lens_and_check_oov(self):
        if self.tgt_texts is None:
            return [0 for _ in range(self.n_samples)]
        tgt_lens = []
        n_tokens, n_oov_tokens = 0, 0
        for i in range(self.n_samples):
            tokenized = self.get_tokenized_tgt_text(i).split(" ")
            oov_tokens = [
                t
                for t in tokenized
                if self.tgt_dict.index(t) == self.tgt_dict.unk_index
            ]
            n_tokens += len(tokenized)
            n_oov_tokens += len(oov_tokens)
            tgt_lens.append(len(tokenized))
        logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
        return tgt_lens

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples:_}, '
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
            f"n_frames_per_step={self.n_frames_per_step}, "
            f"shuffle={self.shuffle}, "
            f"feature_transforms={self.feature_transforms}, "
            f"waveform_transforms={self.waveform_transforms}, "
            f"dataset_transforms={self.dataset_transforms})"
        )

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
        if _is_int_or_np_int(index):
            text = self.tgt_texts[index]
        else:
            text = " ".join([self.tgt_texts[i] for i in index])

        text = self.tokenize(self.pre_tokenizer, text)
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def pack_frames(self, feature: torch.Tensor):
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)
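
    # --- Illustrative aside (not part of the original fairseq sources) ---
    # Sketch of the frame packing above: with n_frames_per_step = 2, a (5, 4)
    # feature matrix drops its trailing odd frame and is reshaped to (2, 8),
    # i.e. two packed steps of 2*4 stacked features each.
    @staticmethod
    def _pack_frames_example():
        feature = torch.zeros(5, 4)
        n_frames_per_step = 2
        n_packed = feature.shape[0] // n_frames_per_step  # 2
        packed = feature[: n_frames_per_step * n_packed].reshape(n_packed, -1)
        assert tuple(packed.shape) == (2, 8)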
|
||||
|
||||
    @classmethod
    def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
        lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
        assert lang_tag_idx != dictionary.unk()
        return lang_tag_idx

    def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
        """
        Gives source audio for given index with any relevant transforms
        applied. For ConcatAug, source audios for given indices are
        concatenated in given order.
        Args:
            index (int or List[int]): index—or in the case of ConcatAug,
            indices—to pull the source audio for
        Returns:
            source audios concatenated for given indices with
            relevant transforms applied
        """
        if _is_int_or_np_int(index):
            source = get_features_or_waveform(
                self.audio_paths[index],
                need_waveform=self.cfg.use_audio_input,
                use_sample_rate=self.cfg.use_sample_rate,
                waveform_transforms=self.waveform_transforms,
            )
        else:
            source = np.concatenate(
                [
                    get_features_or_waveform(
                        self.audio_paths[i],
                        need_waveform=self.cfg.use_audio_input,
                        use_sample_rate=self.cfg.use_sample_rate,
                        waveform_transforms=self.waveform_transforms,
                    )
                    for i in index
                ]
            )
        if self.cfg.use_audio_input:
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    source = F.layer_norm(source, source.shape)
        else:
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source

    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
        has_concat = self.dataset_transforms.has_transform(ConcatAugment)
        if has_concat:
            concat = self.dataset_transforms.get_transform(ConcatAugment)
            indices = concat.find_indices(index, self.n_frames, self.n_samples)

        source = self._get_source_audio(indices if has_concat else index)
        source = self.pack_frames(source)

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

            if self.cfg.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.tgt_dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
                assert lang_tag_idx != self.tgt_dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]
        return SpeechToTextDatasetItem(
            index=index, source=source, target=target, speaker_id=speaker_id
        )

    def __len__(self):
        return self.n_samples

    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)

        sources = [x.source for x in samples]
        has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
        if has_NOAug and self.cfg.use_audio_input:
            NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
            sources = NOAug(sources)

        frames = _collate_frames(sources, self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(x.target.size(0) for x in samples)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        }
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": speaker,
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out

    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        return self.n_frames[index], self.tgt_lens[index]

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)
    def prefetch(self, indices):
        # upstream had `raise False`, which is invalid at runtime (TypeError);
        # raise the intended "not supported" error instead
        raise NotImplementedError
class TextTargetMultitaskData(object):
    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(self, args, split, tgt_dict):
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
        self.dict = tgt_dict
        self.append_eos = args.decoder_type != "ctc"
        self.pre_tokenizer = self.build_tokenizer(args)
        self.bpe_tokenizer = self.build_bpe(args)
        self.prepend_bos_and_append_tgt_lang_tag = (
            args.prepend_bos_and_append_tgt_lang_tag
        )
        self.eos_token = args.eos_token
        self.lang_tag_mapping = args.get_lang_tag_mapping

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: int):
        text = self.tokenize(self.pre_tokenizer, self.data[index])
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
        lang_tag_idx = dictionary.index(lang_tag)
        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
        return lang_tag_idx

    def build_tokenizer(self, args):
        pre_tokenizer = args.config.get("pre_tokenizer")
        if pre_tokenizer is not None:
            logger.info(f"pre-tokenizer: {pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**pre_tokenizer))
        else:
            return None

    def build_bpe(self, args):
        bpe_tokenizer = args.config.get("bpe_tokenizer")
        if bpe_tokenizer is not None:
            logger.info(f"tokenizer: {bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**bpe_tokenizer))
        else:
            return None

    def get(self, sample_id, tgt_lang=None):
        if sample_id in self.data:
            tokenized = self.get_tokenized_tgt_text(sample_id)
            target = self.dict.encode_line(
                tokenized,
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
            if self.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
                assert lang_tag_idx != self.dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)
            return target
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])

    def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

        target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
        ntokens = sum(t.size(0) for t in samples)

        output = {
            "prev_output_tokens": prev_out,
            "target": out,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
        }

        return output


class SpeechToTextMultitaskDataset(SpeechToTextDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
        s2t_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2t_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out


class SpeechToTextDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
    KEY_TGT_TEXT = "tgt_text"
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
        )

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            cfg=cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def get_size_ratios(
        cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
    ) -> List[float]:
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)"""

        id_to_lp, lp_to_sz = {}, defaultdict(int)
        for ds in datasets:
            lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
            assert len(lang_pairs) == 1
            lang_pair = list(lang_pairs)[0]
            id_to_lp[ds.split] = lang_pair
            lp_to_sz[lang_pair] += sum(ds.n_frames)

        sz_sum = sum(v for v in lp_to_sz.values())
        lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
        lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
        prob_sum = sum(v for v in lp_to_tgt_prob.values())
        lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
        lp_to_sz_ratio = {
            k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
        }
        size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]

        p_formatted = {
            k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
        }
        logger.info(f"sampling probability balancing: {p_formatted}")
        sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
        logger.info(f"balanced sampling size ratio: {sr_formatted}")
        return size_ratio
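    # --- Editor's worked example (illustrative, not part of fairseq) --------
    # Two language pairs with 900 and 100 frames at alpha = 0.5:
    #   probs         0.9, 0.1
    #   probs**alpha  0.949, 0.316  -> normalized to 0.75, 0.25
    #   size ratios   0.75 * 1000 / 900 = 0.83 and 0.25 * 1000 / 100 = 2.5
    # so the low-resource pair is upsampled 2.5x while the large pair is
    # slightly downsampled.
    # -------------------------------------------------------------------------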
    @classmethod
    def _load_samples_from_tsv(cls, root: str, split: str):
        tsv_path = Path(root) / f"{split}.tsv"
        if not tsv_path.is_file():
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            samples = [dict(e) for e in reader]
        if len(samples) == 0:
            raise ValueError(f"Empty manifest: {tsv_path}")
        return samples
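    # --- Editor's note: a minimal, made-up manifest accepted by the loader --
    # (tab-separated; id/audio/n_frames/tgt_text are mandatory, rest optional)
    #
    #   id     audio            n_frames  tgt_text      speaker  src_lang  tgt_lang
    #   utt_1  clips/utt_1.wav  48000     hallo welt    spk_a    en        de
    #   utt_2  clips/utt_2.wav  32000     guten morgen  spk_b    en        de
    # -------------------------------------------------------------------------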
    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            multitask,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        n_frames_per_step: int = 1,
        speaker_to_id=None,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        datasets = [
            cls._from_tsv(
                root=root,
                cfg=cfg,
                split=split,
                tgt_dict=tgt_dict,
                is_train_split=is_train_split,
                pre_tokenizer=pre_tokenizer,
                bpe_tokenizer=bpe_tokenizer,
                n_frames_per_step=n_frames_per_step,
                speaker_to_id=speaker_to_id,
                multitask=multitask,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional

import torch

from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)

logger = logging.getLogger(__name__)


class S2TJointDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def src_vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("src_vocab_filename", "src_dict.txt")

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_pre_tokenizer", {"tokenizer": None})

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply on source text after pre-tokenization.
        Returning a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_bpe_tokenizer", {"bpe": None})

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
        to-many multilingual setting). No change needed during inference.
        This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
        """
        value = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if value is None:
            return self.config.get("prepend_tgt_lang_tag_as_bos", False)
        return value

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
        input only) (alpha = 1 for no resampling)"""
        return self.config.get("sampling_text_alpha", 1.0)
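# --- Editor's illustration (key names are real, values are made up): the kind
# of YAML entries the properties above read, as the equivalent Python dict. ---
_example_joint_config = {
    "src_vocab_filename": "src_dict.txt",
    "src_pre_tokenizer": {"tokenizer": "moses"},
    "src_bpe_tokenizer": {"bpe": "sentencepiece", "sentencepiece_model": "spm.model"},
    "prepend_tgt_lang_tag_no_change": True,
    "sampling_text_alpha": 0.5,
}
# ------------------------------------------------------------------------------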
class SpeechToTextJointDatasetItem(NamedTuple):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    src_txt_tokens: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None
    src_lang_tag: Optional[int] = None
    tgt_alignment: Optional[torch.Tensor] = None


# use_src_lang_id:
#   0: don't use src_lang_id
#   1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TJointDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        src_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        append_eos: Optional[bool] = True,
        alignment: Optional[List[str]] = None,
        use_src_lang_id: Optional[int] = 0,
    ):
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            append_eos=append_eos,
        )

        self.src_dict = src_dict
        self.src_pre_tokenizer = src_pre_tokenizer
        self.src_bpe_tokenizer = src_bpe_tokenizer
        self.alignment = None
        self.use_src_lang_id = use_src_lang_id
        if alignment is not None:
            self.alignment = [
                [float(s) for s in sample.split()] for sample in alignment
            ]

    def get_tokenized_src_text(self, index: int):
        text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
        text = self.tokenize(self.src_bpe_tokenizer, text)
        return text

    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
        s2t_dataset_item = super().__getitem__(index)
        src_tokens = None
        src_lang_tag = None
        if self.src_texts is not None and self.src_dict is not None:
            src_tokens = self.get_tokenized_src_text(index)
            src_tokens = self.src_dict.encode_line(
                src_tokens, add_if_not_exist=False, append_eos=True
            ).long()
            if self.use_src_lang_id > 0:
                src_lang_tag = self.get_lang_tag_idx(
                    self.src_langs[index], self.src_dict
                )
        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_no_change:
            # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
        ali = None
        if self.alignment is not None:
            ali = torch.Tensor(self.alignment[index]).float()

        return SpeechToTextJointDatasetItem(
            index=index,
            source=s2t_dataset_item.source,
            target=s2t_dataset_item.target,
            src_txt_tokens=src_tokens,
            tgt_lang_tag=tgt_lang_tag,
            src_lang_tag=src_lang_tag,
            tgt_alignment=ali,
        )

    def __len__(self):
        return self.n_samples

    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
        s2t_out = super().collater(samples, return_order=True)
        if s2t_out == {}:
            return s2t_out
        net_input, order = s2t_out["net_input"], s2t_out["order"]

        if self.src_texts is not None and self.src_dict is not None:
            src_txt_tokens = fairseq_data_utils.collate_tokens(
                [x.src_txt_tokens for x in samples],
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            src_txt_lengths = torch.tensor(
                [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
            )
            if self.use_src_lang_id > 0:
                src_lang_idxs = torch.tensor(
                    [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
                )
                if self.use_src_lang_id == 1:  # replace eos with lang_id
                    eos_idx = src_txt_lengths - 1
                    src_txt_tokens.scatter_(
                        1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
                    )
                else:
                    raise NotImplementedError("Implementation is required")

            src_txt_tokens = src_txt_tokens.index_select(0, order)
            src_txt_lengths = src_txt_lengths.index_select(0, order)
            net_input["src_txt_tokens"] = src_txt_tokens
            net_input["src_txt_lengths"] = src_txt_lengths

        net_input["alignment"] = None
        if self.alignment is not None:
            max_len = max([s.tgt_alignment.size(0) for s in samples])
            alignment = torch.ones(len(samples), max_len).float()
            for i, s in enumerate(samples):
                cur_len = s.tgt_alignment.size(0)
                alignment[i][:cur_len].copy_(s.tgt_alignment)
            net_input["alignment"] = alignment.index_select(0, order)

        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag

        out = {
            "id": s2t_out["id"],
            "net_input": net_input,
            "target": s2t_out["target"],
            "target_lengths": s2t_out["target_lengths"],
            "ntokens": s2t_out["ntokens"],
            "nsentences": len(samples),
        }
        return out


class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
    KEY_ALIGN = "align"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TJointDataConfig,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos,
        use_src_lang_id,
    ) -> SpeechToTextJointDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_alignment = None
        if cls.KEY_ALIGN in samples[0].keys():
            tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
        return SpeechToTextJointDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            src_dict=src_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            src_pre_tokenizer=src_pre_tokenizer,
            src_bpe_tokenizer=src_bpe_tokenizer,
            append_eos=append_eos,
            alignment=tgt_alignment,
            use_src_lang_id=use_src_lang_id,
        )

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        split: str,
        tgt_dict,
        src_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos: bool,
        use_src_lang_id: int,
    ) -> SpeechToTextJointDataset:
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            src_pre_tokenizer,
            src_bpe_tokenizer,
            append_eos,
            use_src_lang_id,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        append_eos: Optional[bool] = True,
        use_src_lang_id: Optional[int] = 0,
    ) -> SpeechToTextJointDataset:
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
@@ -0,0 +1,250 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    _collate_frames,
)


@dataclass
class TextToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class TextToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        super(TextToSpeechDataset, self).__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies

    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()
        return TextToSpeechDatasetItem(
            index=index,
            source=s2t_item.source,
            target=s2t_item.target,
            speaker_id=s2t_item.speaker_id,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
            0, order
        )
        feat = _collate_frames(
            [s.source for s in samples], self.cfg.use_audio_input
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }
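# --- Editor's sketch (illustrative, not part of fairseq): the collater above
# builds prev_output_tokens as the usual teacher-forcing shift, conditioning
# each decoder step on the previous target frame with an all-zero first frame.
if __name__ == "__main__":
    _feat = torch.arange(2 * 3 * 4, dtype=torch.float).reshape(2, 3, 4)  # (B, T, D)
    _bsz, _, _d = _feat.size()
    _prev = torch.cat((_feat.new_zeros((_bsz, 1, _d)), _feat[:, :-1, :]), dim=1)
    assert torch.equal(_prev[:, 1:], _feat[:, :-1])  # shifted right by one frame
# ------------------------------------------------------------------------------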
class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask=None,
    ) -> TextToSpeechDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
        durations = [
            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
        ]
        durations = None if any(dd is None for dd in durations) else durations

        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
        pitches = [
            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches

        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
        energies = [
            None if ee is None else (audio_root / ee).as_posix() for ee in energies
        ]
        energies = None if any(ee is None for ee in energies) else energies

        return TextToSpeechDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )
@@ -0,0 +1,48 @@
import os

from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioWaveformTransform(AudioTransform):
    pass


AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()


def get_audio_waveform_transform(name):
    return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]


def register_audio_waveform_transform(name):
    return register_audio_transform(
        name,
        AudioWaveformTransform,
        AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
        AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "waveform")


class CompositeAudioWaveformTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "waveform",
            get_audio_waveform_transform,
            CompositeAudioWaveformTransform,
            config,
        )

    def __call__(self, x, sample_rate):
        for t in self.transforms:
            x, sample_rate = t(x, sample_rate)
        return x, sample_rate
@@ -0,0 +1,201 @@
from pathlib import Path
import numpy as np
from math import ceil

from fairseq.data.audio import rand_uniform
from fairseq.data.audio.waveform_transforms import (
    AudioWaveformTransform,
    register_audio_waveform_transform,
)

SNR_MIN = 5.0
SNR_MAX = 15.0
RATE = 0.25

NOISE_RATE = 1.0
NOISE_LEN_MEAN = 0.2
NOISE_LEN_STD = 0.05


class NoiseAugmentTransform(AudioWaveformTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(
            _config.get("samples_path", None),
            _config.get("snr_min", SNR_MIN),
            _config.get("snr_max", SNR_MAX),
            _config.get("rate", RATE),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
    ):
        # Sanity checks
        assert (
            samples_path
        ), "need to provide path to audio samples for noise augmentation"
        assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
        assert rate >= 0 and rate <= 1, "rate should be a float between 0 and 1"

        self.paths = list(Path(samples_path).glob("**/*.wav"))  # load music
        self.n_samples = len(self.paths)
        assert self.n_samples > 0, f"no audio files found in {samples_path}"

        self.snr_min = snr_min
        self.snr_max = snr_max
        self.rate = rate

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"n_samples={self.n_samples}",
                    f"snr={self.snr_min}-{self.snr_max}dB",
                    f"rate={self.rate}",
                ]
            )
            + ")"
        )

    def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
        from fairseq.data.audio.audio_utils import get_waveform

        path = self.paths[np.random.randint(0, self.n_samples)]
        sample = get_waveform(
            path, always_2d=always_2d, output_sample_rate=use_sample_rate
        )[0]

        # Check dimensions match, else silently skip adding noise to sample
        # NOTE: SHOULD THIS QUIT WITH AN ERROR?
        is_2d = len(goal_shape) == 2
        if len(goal_shape) != sample.ndim or (
            is_2d and goal_shape[0] != sample.shape[0]
        ):
            return np.zeros(goal_shape)

        # Cut/repeat sample to size
        len_dim = len(goal_shape) - 1
        n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
        repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
        start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
        return (
            repeated[:, start : start + goal_shape[len_dim]]
            if is_2d
            else repeated[start : start + goal_shape[len_dim]]
        )

    def _mix(self, source, noise, snr):
        get_power = lambda x: np.mean(x**2)
        if get_power(noise):
            scl = np.sqrt(
                get_power(source) / (np.power(10, snr / 10) * get_power(noise))
            )
        else:
            scl = 0
        return 1 * source + scl * noise
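    # --- Editor's note (illustrative, not part of fairseq) ------------------
    # The scale above solves snr = 10 * log10(P_src / (scl**2 * P_noise)) for
    # the noise gain, i.e. scl = sqrt(P_src / (10**(snr / 10) * P_noise)), so
    # mixing `source + scl * noise` realizes the sampled target SNR exactly:
    #
    #     p = lambda x: np.mean(x**2)
    #     scl = np.sqrt(p(src) / (10 ** (snr / 10) * p(noise)))
    #     10 * np.log10(p(src) / p(scl * noise))  # == snr (up to rounding)
    # -------------------------------------------------------------------------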
    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        return self.pick_sample(goal_shape, always_2d, use_sample_rate)

    def __call__(self, source, sample_rate):
        if np.random.random() > self.rate:
            return source, sample_rate

        noise = self._get_noise(
            source.shape, always_2d=True, use_sample_rate=sample_rate
        )

        return (
            self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
            sample_rate,
        )


@register_audio_waveform_transform("musicaugment")
class MusicAugmentTransform(NoiseAugmentTransform):
    pass
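# --- Editor's usage sketch (the path is made up; assumes a directory of wav
# files exists there; not part of fairseq) -------------------------------------
if __name__ == "__main__":
    aug = MusicAugmentTransform.from_config_dict(
        {"samples_path": "/data/noise_wavs", "snr_min": 5.0, "snr_max": 15.0, "rate": 1.0}
    )
    waveform = np.random.randn(1, 16000)  # (channels, samples), 2d as __call__ expects
    augmented, sr = aug(waveform, 16000)  # returns (mixed waveform, sample_rate)
# ------------------------------------------------------------------------------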
@register_audio_waveform_transform("backgroundnoiseaugment")
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
    pass


@register_audio_waveform_transform("babbleaugment")
class BabbleAugmentTransform(NoiseAugmentTransform):
    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        for i in range(np.random.randint(3, 8)):
            speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
            if i == 0:
                agg_noise = speech
            else:  # SNR scaled by i (how many noise signals already in agg_noise)
                agg_noise = self._mix(agg_noise, speech, i)
        return agg_noise


@register_audio_waveform_transform("sporadicnoiseaugment")
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(
            _config.get("samples_path", None),
            _config.get("snr_min", SNR_MIN),
            _config.get("snr_max", SNR_MAX),
            _config.get("rate", RATE),
            _config.get("noise_rate", NOISE_RATE),
            _config.get("noise_len_mean", NOISE_LEN_MEAN),
            _config.get("noise_len_std", NOISE_LEN_STD),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
        noise_rate: float = NOISE_RATE,  # noises per second
        noise_len_mean: float = NOISE_LEN_MEAN,  # length of noises in seconds
        noise_len_std: float = NOISE_LEN_STD,
    ):
        super().__init__(samples_path, snr_min, snr_max, rate)
        self.noise_rate = noise_rate
        self.noise_len_mean = noise_len_mean
        self.noise_len_std = noise_len_std

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        agg_noise = np.zeros(goal_shape)
        len_dim = len(goal_shape) - 1
        is_2d = len(goal_shape) == 2

        n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
        start_pointers = [
            round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
        ]

        for start_pointer in start_pointers:
            noise_shape = list(goal_shape)
            len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
            noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
            end_pointer = start_pointer + noise_shape[len_dim]
            if end_pointer >= goal_shape[len_dim]:
                continue

            noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
            if is_2d:
                agg_noise[:, start_pointer:end_pointer] = (
                    agg_noise[:, start_pointer:end_pointer] + noise
                )
            else:
                agg_noise[start_pointer:end_pointer] = (
                    agg_noise[start_pointer:end_pointer] + noise
                )

        return agg_noise
165
modules/voice_conversion/fairseq/data/backtranslation_dataset.py
Normal file
@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils

from . import FairseqDataset


def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this will return:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have a 'source' key, which will become the 'target'
            after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch
        generate_fn (callable): function to generate backtranslations
        cuda (bool): use GPU for generation (default: ``True``)

    Returns:
        List[dict]: an updated list of samples with a backtranslated source
    """
    collated_samples = collate_fn(samples)
    s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
    generated_sources = generate_fn(s)

    id_to_src = {sample["id"]: sample["source"] for sample in samples}

    # Go through each tgt sentence in batch and its corresponding best
    # generated hypothesis and create a backtranslation data pair
    # {id: id, source: generated backtranslation, target: original tgt}
    return [
        {
            "id": id.item(),
            "target": id_to_src[id.item()],
            "source": hypos[0]["tokens"].cpu(),
        }
        for id, hypos in zip(collated_samples["id"], generated_sources)
    ]


class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn function to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation
    """

    def __init__(
        self,
        tgt_dataset,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
    ):
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        self.output_collater = (
            output_collater if output_collater is not None else tgt_dataset.collater
        )
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
        Returns a single sample from *tgt_dataset*. Note that backtranslation is
        not applied in this step; use :func:`collater` instead to backtranslate
        a batch of samples.
        """
        return self.tgt_dataset[index]

    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

        Using the samples from *tgt_dataset*, load a collated target sample to
        feed to the backtranslation model. Then take the backtranslation with
        the best score as the source and the original input as the target.

        Note: we expect *tgt_dataset* to provide a function `collater()` that
        will collate samples into the format expected by *backtranslation_fn*.
        After backtranslation, we will feed the new list of samples (i.e., the
        `(backtranslated source, original source)` pairs) to *output_collater*
        and return the result.

        Args:
            samples (List[dict]): samples to backtranslate and collate

        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        if samples[0].get("is_dummy", False):
            return samples
        samples = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
            generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
            cuda=self.cuda,
        )
        return self.output_collater(samples)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``.

        Note: we use *tgt_dataset* to approximate the length of the source
        sentence, since we do not know the actual length until after
        backtranslation.
        """
        tgt_size = self.tgt_dataset.size(index)[0]
        return (tgt_size, tgt_size)

    @property
    def supports_prefetch(self):
        return getattr(self.tgt_dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)
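# --- Editor's wiring sketch (hypothetical names, not part of fairseq) --------
# Assumes `mono_de` is a FairseqDataset of target-side sentences and
# `generator`/`model` form a trained tgt->src SequenceGenerator:
#
#     bt = BacktranslationDataset(
#         tgt_dataset=mono_de,
#         src_dict=src_dict,
#         tgt_dict=tgt_dict,
#         backtranslation_fn=None,  # generator may not exist yet at init time
#     )
#     bt.set_backtranslation_fn(lambda sample: generator.generate([model], sample))
#     batch = bt.collater([bt[i] for i in range(8)])
#     # `batch` pairs generated sources with the original monolingual targets
# ------------------------------------------------------------------------------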
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.nn.functional as F

from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes


class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDataset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
        tensor_key (str, optional): if set, pad item[tensor_key] instead of
            the item itself
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
        tensor_key=None,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        self.buckets = get_buckets(sizes, num_buckets)
        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
        self._tensor_key = tensor_key

    def _set_tensor(self, item, val):
        if self._tensor_key is None:
            return val
        item[self._tensor_key] = val
        return item

    def _get_tensor(self, item):
        if self._tensor_key is None:
            return item
        return item[self._tensor_key]

    def _pad(self, tensor, bucket_size, dim=-1):
        num_pad = bucket_size - tensor.size(dim)
        return F.pad(
            tensor,
            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
            value=self.pad_idx,
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        tensor = self._get_tensor(item)
        padded = self._pad(tensor, bucket_size)
        return self._set_tensor(item, padded)

    @property
    def sizes(self):
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]
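
# --- Illustrative sketch (not part of the upstream file): what the _pad step
# above does to one item once its bucket size is known. Values are invented.
import torch
import torch.nn.functional as F

item = torch.tensor([7, 8, 9])  # length 3
bucket_size, pad_idx = 5, 1
right_padded = F.pad(item, (0, bucket_size - item.size(-1)), value=pad_idx)
# right_padded -> tensor([7, 8, 9, 1, 1]); with left_pad=True the padding
# would be prepended instead.
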
576
modules/voice_conversion/fairseq/data/codedataset.py
Normal file
@@ -0,0 +1,576 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import json
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.utils.data

from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

F0_FRAME_SPACE = 0.005  # sec


logger = logging.getLogger(__name__)


class ExpressiveCodeDataConfig(object):
    def __init__(self, json_path):
        with open(json_path, "r") as f:
            self.config = json.load(f)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """naive or precomp"""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        key = "log" if log else "linear"
        if norm_mean and norm_std:
            key += "_mean_std_norm"
        elif norm_mean:
            key += "_mean_norm"
        else:
            key += "_none_norm"
        return self.config["f0_vq_naive_quantizer"][key]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)
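
# --- Illustrative sketch (not upstream code): a minimal JSON layout that
# ExpressiveCodeDataConfig can read, inferred from the accessors above.
# All paths and values here are invented.
import json
import tempfile

_cfg = {
    "manifests": {"train": "/data/train.txt", "valid": "/data/valid.txt"},
    "n_units": 100,
    "sampling_rate": 16000,
    "code_hop_size": 320,
    "f0_vq_type": "naive",
    "f0_vq_name": "f0_vq",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as _f:
    json.dump(_cfg, _f)
_config = ExpressiveCodeDataConfig(_f.name)
assert _config.sampling_rate == 16000
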
def get_f0(audio, rate=16000):
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError:
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) and "
            "librosa (`pip install librosa`)."
        )

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0


def interpolate_f0(f0):
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise ImportError("Please install scipy (`pip install scipy`)")

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp


def naive_quantize(x, edges):
    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
    return bin_idx


def load_wav(full_path):
    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile (`pip install SoundFile`)")
    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate


def parse_code(code_str, dictionary, append_eos):
    code, duration = torch.unique_consecutive(
        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
    )
    code = " ".join(map(str, code.tolist()))
    code = dictionary.encode_line(code, append_eos).short()

    if append_eos:
        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)  # eos
    duration = duration.short()
    return code, duration
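
# --- Worked example (not upstream code): the run-length step inside
# parse_code. Repeated units collapse to one code plus a duration count.
import torch

_codes = torch.ShortTensor([4, 4, 4, 9, 9, 4])
_units, _counts = torch.unique_consecutive(_codes, return_counts=True)
# _units  -> tensor([4, 9, 4], dtype=torch.int16)
# _counts -> tensor([3, 2, 1])   (kept as per-code durations)
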
def parse_manifest(manifest, dictionary):
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info.readlines():
            sample = eval(line.strip())
            if "cpc_km100" in sample:
                k = "cpc_km100"
            elif "hubert_km100" in sample:
                k = "hubert_km100"
            elif "phone" in sample:
                k = "phone"
            else:
                assert False, "unknown format"
            code = sample[k]
            code, duration = parse_code(code, dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers


def parse_speaker(path, method):
    if type(path) == str:
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()


def get_f0_by_filename(filename, tgt_sampling_rate):
    audio, sampling_rate = load_wav(filename)
    if sampling_rate != tgt_sampling_rate:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
    f0 = get_f0(audio, rate=tgt_sampling_rate)
    f0 = torch.from_numpy(f0.astype(np.float32))
    return f0


def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        f0 = f0[:targ_len]
    elif diff < 0:
        f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)

    f0_offset = 0.0
    seg_f0s = []
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        seg_f0 = seg_f0[seg_f0 != 0]
        if len(seg_f0) == 0:
            seg_f0 = torch.tensor(0).type(seg_f0.type())
        else:
            seg_f0 = seg_f0.mean()
        seg_f0s.append(seg_f0)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)


class Paddings(object):
    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code = code_val
        self.dur = dur_val
        self.f0 = f0_val


class Shifts(object):
    def __init__(self, shifts_str, pads):
        self._shifts = list(map(int, shifts_str.split(",")))
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(s for s in self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        assert seq.ndim == 1
        bos = seq.new_full((left_pad_num,), pad)
        eos = seq.new_full((right_pad_num,), pad)
        seq = torch.cat([bos, seq, eos])
        mask = torch.ones_like(seq).bool()
        mask[left_pad_num : len(seq) - right_pad_num] = 0
        return seq, mask

    def __call__(self, code, dur, f0):
        if self.extra_length == 0:
            code_mask = torch.zeros_like(code).bool()
            dur_mask = torch.zeros_like(dur).bool()
            f0_mask = torch.zeros_like(f0).bool()
            return code, code_mask, dur, dur_mask, f0, f0_mask

        code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, self.extra_length - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(
            f0, self.f0, self.extra_length - self.f0, self.pads.f0
        )
        return code, code_mask, dur, dur_mask, f0, f0_mask
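
# --- Illustrative sketch (not upstream code) of Shifts.shift_one: pad both
# ends of a sequence and mark exactly the added positions in the mask.
import torch

_seq = torch.tensor([5, 6, 7])
_shifted, _mask = Shifts.shift_one(_seq, left_pad_num=1, right_pad_num=2, pad=0)
# _shifted -> tensor([0, 5, 6, 7, 0, 0])
# _mask    -> tensor([ True, False, False, False,  True,  True])
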
class CodeDataset(FairseqDataset):
    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        # duration config
        self.discrete_dur = discrete_dur

        # pitch config
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # use lazy loading to avoid sharing file handlers across workers
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None
        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # use 0 for duration padding
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0

    def get_data_handlers(self):
        logging.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")

    def preprocess_f0(self, f0, stats):
        """
        1. interpolate
        2. log transform (keep unvoiced frame 0)
        """
        # TODO: change this to be dependent on config for naive quantizer
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        mask = f0 != 0  # only process voiced frames
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0

    def _get_raw_item(self, index):
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0

    def __getitem__(self, index):
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # use 0 for eos and bos
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        # TODO: find a more elegant approach
        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
                if self.return_continuous_f0:
                    raw_f0 = f0
                    raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0])
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            if self.multispkr:
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
            else:
                f0 = self.preprocess_f0(f0, self.f0_stats)
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None

        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # is a padded frame if either input or output is padded
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats

    def __len__(self):
        return len(self.starts)

    def size(self, index):
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        return self.size(index)

    def collater(self, samples):
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,  # appending padding, eos is there already
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        # TODO: remove this hack into the inference dataset
        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result
25
modules/voice_conversion/fairseq/data/colorize_dataset.py
Normal file
@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset


class ColorizeDataset(BaseWrapperDataset):
    """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        base_collate = super().collater(samples)
        if len(base_collate) > 0:
            base_collate["net_input"]["colors"] = torch.tensor(
                list(self.color_getter(self.dataset, s["id"]) for s in samples),
                dtype=torch.long,
            )
        return base_collate
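
# --- Hypothetical usage (not upstream code): color_getter receives the
# wrapped dataset and a sample id and returns an integer "color", such as a
# domain label; the invented rule below is purely for illustration.
def _domain_of(dataset, sample_id):
    return 0 if sample_id < 1000 else 1

# colored = ColorizeDataset(base_dataset, color_getter=_domain_of)
# After collation, batches expose batch["net_input"]["colors"] to the model.
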
124
modules/voice_conversion/fairseq/data/concat_dataset.py
Normal file
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length, so that less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            indices = np.arange(len(self))
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
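
# --- Worked example (values invented): how ConcatDataset.cumsum turns sample
# ratios into index boundaries. Datasets of lengths 4 and 2, the second
# upsampled 3x, give cumulative sizes [4, 10]; indices 4..9 all fall in the
# second dataset and wrap around via `sample_idx % real_size`.
_lengths, _ratios = [4, 2], [1, 3]
_cum, _s = [], 0
for _n, _r in zip(_lengths, _ratios):
    _s += int(_r * _n)
    _cum.append(_s)
assert _cum == [4, 10]
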
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class ConcatSentencesDataset(FairseqDataset):
    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        return len(self.datasets[0])

    def collater(self, samples):
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        for ds in self.datasets:
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
604
modules/voice_conversion/fairseq/data/data_utils.py
Normal file
@@ -0,0 +1,604 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
import contextlib
import itertools
import logging
import re
import warnings
from typing import Optional, Tuple

import numpy as np
import torch

from fairseq.file_io import PathManager
from fairseq import utils
import os

logger = logging.getLogger(__name__)


def infer_language_pair(path):
    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
    src, dst = None, None
    for filename in PathManager.ls(path):
        parts = filename.split(".")
        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
            return parts[1].split("-")
    return src, dst


def collate_tokens(
    values,
    pad_idx,
    eos_idx=None,
    left_pad=False,
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    size = size if pad_to_length is None else max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
    res = values[0].new(batch_size, size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            if eos_idx is None:
                # if no eos_idx is specified, then use the last token in src
                dst[0] = src[-1]
            else:
                dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
    return res
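
# --- Illustrative call (not upstream code): two sequences of different
# lengths are right-padded into a single 2-D batch; pad_idx=1 is arbitrary.
import torch

_batch = collate_tokens([torch.tensor([4, 5, 6]), torch.tensor([7, 8])], pad_idx=1)
# _batch -> tensor([[4, 5, 6],
#                   [7, 8, 1]])
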
def load_indexed_dataset(
    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    import fairseq.data.indexed_dataset as indexed_dataset
    from fairseq.data.concat_dataset import ConcatDataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else "")
        try:
            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
        except Exception as e:
            if "StorageException: [404] Path not found" in str(e):
                logger.warning(f"path_k: {e} not found")
            else:
                raise e

        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or default,
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break
    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)


@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager which seeds the NumPy PRNG with the specified seed and
    restores the state afterward"""
    if seed is None:
        yield
        return
    if len(addl_seeds) > 0:
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
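
# --- Illustrative use (not upstream code): the RNG state is restored on
# exit, so the permutation is reproducible without disturbing the global
# NumPy stream; extra seeds (e.g. an epoch number) are hashed in.
import numpy as np

with numpy_seed(42, 3):
    _order = np.random.permutation(5)
# np.random is now back in the state it had before the block.
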
def collect_filtered(function, iterable, filtered):
    """
    Similar to :func:`filter` but collects filtered elements in ``filtered``.

    Args:
        function (callable): function that returns ``False`` for elements that
            should be filtered
        iterable (iterable): iterable to filter
        filtered (list): list to store filtered elements
    """
    for el in iterable:
        if function(el):
            yield el
        else:
            filtered.append(el)


def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
    def compare_leq(a, b):
        return a <= b if not isinstance(a, tuple) else max(a) <= b

    def check_size(idx):
        if isinstance(max_positions, float) or isinstance(max_positions, int):
            return size_fn(idx) <= max_positions
        elif isinstance(max_positions, dict):
            idx_size = size_fn(idx)
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
                all(
                    a is None or b is None or a <= b
                    for a, b in zip(idx_size[key], max_positions[key])
                )
                for key in intersect_keys
            )
        else:
            # For MultiCorpusSampledDataset, will generalize it later
            if not isinstance(size_fn(idx), Iterable):
                return all(size_fn(idx) <= b for b in max_positions)
            return all(
                a is None or b is None or a <= b
                for a, b in zip(size_fn(idx), max_positions)
            )

    ignored = []
    itr = collect_filtered(check_size, indices, ignored)
    indices = np.fromiter(itr, dtype=np.int64, count=-1)
    return indices, ignored


def filter_by_size(indices, dataset, max_positions, raise_exception=False):
    """
    [deprecated] Filter indices based on their size.
    Use `FairseqDataset::filter_indices_by_size` instead.

    Args:
        indices (List[int]): ordered list of dataset indices
        dataset (FairseqDataset): fairseq dataset instance
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    warnings.warn(
        "data_utils.filter_by_size is deprecated. "
        "Use `FairseqDataset::filter_indices_by_size` instead.",
        stacklevel=2,
    )
    if isinstance(max_positions, float) or isinstance(max_positions, int):
        if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
            ignored = indices[dataset.sizes[indices] > max_positions].tolist()
            indices = indices[dataset.sizes[indices] <= max_positions]
        elif (
            hasattr(dataset, "sizes")
            and isinstance(dataset.sizes, list)
            and len(dataset.sizes) == 1
        ):
            ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
            indices = indices[dataset.sizes[0][indices] <= max_positions]
        else:
            indices, ignored = _filter_by_size_dynamic(
                indices, dataset.size, max_positions
            )
    else:
        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)

    if len(ignored) > 0 and raise_exception:
        raise Exception(
            (
                "Size of sample #{} is invalid (={}) since max_positions={}, "
                "skip this example with --skip-invalid-size-inputs-valid-test"
            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
        )
    if len(ignored) > 0:
        logger.warning(
            (
                "{} samples have invalid sizes and will be skipped, "
                "max_positions={}, first few sample ids={}"
            ).format(len(ignored), max_positions, ignored[:10])
        )
    return indices


def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
    """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

    Args:
        indices (np.array): original array of sample indices
        max_sizes (int or list[int] or tuple[int]): max sample size,
            can be defined separately for src and tgt (then list or tuple)

    Returns:
        np.array: filtered sample array
        list: list of removed indices
    """
    if max_sizes is None:
        return indices, []
    if type(max_sizes) in (int, float):
        max_src_size, max_tgt_size = max_sizes, max_sizes
    else:
        max_src_size, max_tgt_size = max_sizes
    if tgt_sizes is None:
        ignored = indices[src_sizes[indices] > max_src_size]
    else:
        ignored = indices[
            (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
        ]
    if len(ignored) > 0:
        if tgt_sizes is None:
            indices = indices[src_sizes[indices] <= max_src_size]
        else:
            indices = indices[
                (src_sizes[indices] <= max_src_size)
                & (tgt_sizes[indices] <= max_tgt_size)
            ]
    return indices, ignored.tolist()


def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        )
    except ValueError:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        )

    # added int() to avoid TypeError: an integer is required
    max_tokens = int(max_tokens) if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        if num_tokens_vec is None:
            return batch_by_size_fn(
                indices,
                num_tokens_fn,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
        else:
            return batch_by_size_vec(
                indices,
                num_tokens_vec,
                max_tokens,
                max_sentences,
                bsz_mult,
            )

    else:
        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
        sort_order = np.lexsort(
            [
                fixed_shapes[:, 1].argsort(),  # length
                fixed_shapes[:, 0].argsort(),  # bsz
            ]
        )
        fixed_shapes_sorted = fixed_shapes[sort_order]
        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
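
# --- Hedged usage sketch (not upstream code; requires the compiled Cython
# extension, so it is left commented out). A batch costs
# len(batch) * longest-item-in-batch tokens, so with num_tokens [3, 3, 4, 4]
# and max_tokens=8 the expected grouping is [[0, 1], [2, 3]]:
#
#   num_tokens = [3, 3, 4, 4]
#   batches = batch_by_size(range(4), lambda i: num_tokens[i], max_tokens=8)
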
def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        # `re` is already imported at module level
        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence
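
# --- Worked example (not upstream code): the sentencepiece branch. Spaces
# separate pieces, and "\u2581" marks true word boundaries.
assert post_process("\u2581the \u2581quick \u2581fox", "sentencepiece") == "the quick fox"
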
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    np.int64,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
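
# --- Illustrative call (not upstream code): mask roughly half of 100
# timesteps per sample with fixed spans of length 10. Positions are random;
# overlaps make the realized count slightly below 50.
_mask = compute_mask_indices(
    shape=(2, 100),
    padding_mask=None,
    mask_prob=0.5,
    mask_length=10,
)
# _mask.shape -> (2, 100); each row ends up with the same number of True
# entries because require_same_masks defaults to True.
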
def get_mem_usage():
    try:
        import psutil

        mb = 1024 * 1024
        return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
    except ImportError:
        return "N/A"


# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_mask(lens):
    return ~lengths_to_padding_mask(lens)
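
# --- Worked example (not upstream code): positions at or beyond each
# sequence's length are True (padding); lengths_to_mask is the complement.
import torch

_lens = torch.LongTensor([3, 1])
_pad_mask = lengths_to_padding_mask(_lens)
# _pad_mask -> tensor([[False, False, False],
#                      [False,  True,  True]])
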
def get_buckets(sizes, num_buckets):
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation="lower",
        )[1:]
    )
    return buckets


def get_bucketed_sizes(orig_sizes, buckets):
    sizes = np.copy(orig_sizes)
    assert np.min(sizes) >= 0
    start_val = -1
    for end_val in buckets:
        mask = (sizes > start_val) & (sizes <= end_val)
        sizes[mask] = end_val
        start_val = end_val
    return sizes
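
# --- Worked example (not upstream code): with num_buckets=2 the percentile
# edges collapse six distinct lengths onto just two padded sizes.
import numpy as np

_sizes = np.array([3, 5, 6, 9, 10, 17])
_buckets = get_buckets(_sizes, 2)              # -> array([ 6, 17])
_bucketed = get_bucketed_sizes(_sizes, _buckets)
# _bucketed -> array([ 6,  6,  6, 17, 17, 17])
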
def _find_extra_valid_paths(dataset_path: str) -> set:
    paths = utils.split_paths(dataset_path)
    all_valid_paths = set()
    for sub_dir in paths:
        contents = PathManager.ls(sub_dir)
        valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
        all_valid_paths |= {os.path.basename(p) for p in valid_paths}
    # Remove .bin, .idx etc
    roots = {os.path.splitext(p)[0] for p in all_valid_paths}
    return roots


def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    if (
        train_cfg.dataset.ignore_unused_valid_subsets
        or train_cfg.dataset.combine_valid_subsets
        or train_cfg.dataset.disable_validation
        or not hasattr(train_cfg.task, "data")
    ):
        return
    other_paths = _find_extra_valid_paths(train_cfg.task.data)
    specified_subsets = train_cfg.dataset.valid_subset.split(",")
    ignored_paths = [p for p in other_paths if p not in specified_subsets]
    if ignored_paths:
        advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
        msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
        raise ValueError(msg)
178
modules/voice_conversion/fairseq/data/data_utils_fast.pyx
Normal file
@@ -0,0 +1,178 @@
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t

ctypedef int64_t DTYPE_t


@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    if indices.shape[0] == 0:
        return []

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
        np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end),
        # and tail [batch_end:pos].
        # 1) Every time when (batch + tail) forms a valid batch
        #    (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
        # 2) When (batch+tail) violates max_tokens or max_sentences constraints
        #    we finalize running batch, and tail becomes a new batch.
        # 3) There is a corner case when tail also violates constraints.
        #    In that situation [batch_end:pos-1] (tail without the current pos)
        #    gets added to the finalized batches, while [pos:pos] becomes a new tail.
        #
        # Important: For the sake of performance try to avoid using function calls within this loop.

        tail_max_tokens = tail_max_tokens \
            if tail_max_tokens > num_tokens_view[pos] \
            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
            if batch_max_tokens > tail_max_tokens \
            else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
                                                               dtype=np.int64)
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos
    for pos in range(indices_len):
        num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
    return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
                             max_sentences, bsz_mult,)


cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of the first valid shape, or -1 if none is found."""
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1


@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    cdef int64_t sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            batches.append(batch)
            batch = []
            sample_lens = []
            sample_len = 0
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
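
# --- Illustrative note (not upstream code): with
# fixed_shapes=[(8, 512), (16, 256)], batch_fixed_shapes_fast emits batches
# of at most 8 sentences while the longest sample in the batch needs up to
# 512 tokens, or up to 16 sentences when every sample fits in 256 tokens;
# a sample that fits no shape closes the current batch and starts a new one.
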
443
modules/voice_conversion/fairseq/data/denoising_dataset.py
Normal file
@@ -0,0 +1,443 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "nsentences": samples[0]["source"].size(0),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch


class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        mask,
        mask_random,
        insert,
        rotate,
        permute_sentences,
        bpe,
        replace_length,
        mask_length,
        poisson_lambda,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = mask
        self.random_ratio = mask_random
        self.insert_ratio = insert
        self.rotate_ratio = rotate
        self.permute_sentence_ratio = permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        if bpe != "gpt2":
            self.full_stop_index = self.vocab.eos()
        else:
            assert bpe == "gpt2"
            self.full_stop_index = self.vocab.index("13")

        self.replace_length = replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            _lambda = poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # there can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.src, "supports_prefetch")
            and self.src.supports_prefetch
            and hasattr(self.tgt, "supports_prefetch")
            and self.tgt.supports_prefetch
        )
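The span-Poisson branch in DenoisingDataset.__init__ above builds a truncated Poisson table by hand. A standalone sketch of the same construction, assuming lambda = 3.5 for illustration (the real value comes from the poisson_lambda argument):

import math
import torch

_lambda = 3.5  # illustrative; DenoisingDataset reads this from poisson_lambda
ps, lambda_to_the_k, k_factorial = [], 1.0, 1.0
for k in range(0, 128):
    ps.append(math.exp(-_lambda) * lambda_to_the_k / k_factorial)  # P(X = k)
    lambda_to_the_k *= _lambda
    k_factorial *= k + 1
    if ps[-1] < 1e-7:  # truncate the tail once probabilities vanish
        break
dist = torch.distributions.Categorical(torch.FloatTensor(ps))  # renormalizes
span_lengths = dist.sample(sample_shape=(8,))  # span lengths; 0 means insertion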
401
modules/voice_conversion/fairseq/data/dictionary.py
Normal file
@@ -0,0 +1,401 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import Counter
from multiprocessing import Pool

import torch
from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line


class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    include_eos=include_eos,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        c = Counter(
            dict(
                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
            )
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

        self.pad_to_multiple_(padding_factor)

    def pad_to_multiple_(self, padding_factor):
        """Pad Dictionary size to be a multiple of *padding_factor*."""
        if padding_factor > 1:
            i = 0
            while len(self) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                self.add_symbol(symbol, n=0)
                i += 1

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                )
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
                )

    def _save(self, f, kv_iterator):
        if isinstance(f, str):
            PathManager.mkdirs(os.path.dirname(f))
            with PathManager.open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        return [], []

    def _load_meta(self, lines):
        return 0

    def save(self, f):
        """Stores dictionary into a text file"""
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(
                ex_keys + self.symbols[self.nspecial :],
                ex_vals + self.count[self.nspecial :],
            ),
        )

    def dummy_sentence(self, length):
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t

    def encode_line(
        self,
        line,
        line_tokenizer=tokenize_line,
        add_if_not_exist=True,
        consumer=None,
        append_eos=True,
        reverse_order=False,
    ) -> torch.IntTensor:
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids

    @staticmethod
    def _add_file_to_dictionary_single_worker(
        filename,
        tokenize,
        eos_word,
        start_offset,
        end_offset,
    ):
        counter = Counter()
        with Chunker(filename, start_offset, end_offset) as line_iterator:
            for line in line_iterator:
                for word in tokenize(line):
                    counter.update([word])
                counter.update([eos_word])
        return counter

    @staticmethod
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )


class TruncatedDictionary(object):
    def __init__(self, wrapped_dict, length):
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {},
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        if i < self.length:
            return self.wrapped_dict[i]
        return self.wrapped_dict.unk()
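A hedged usage sketch for the Dictionary class above, assuming the vendored package is importable as fairseq (as the surrounding voice_conversion code arranges); the tokens are illustrative.

from fairseq.data.dictionary import Dictionary

d = Dictionary()
for tok in ["hello", "world", "hello"]:
    d.add_symbol(tok)                      # counts accumulate on repeats
d.finalize(padding_factor=8)               # sort by frequency, pad size to a multiple of 8
ids = d.encode_line("hello world", add_if_not_exist=False)  # IntTensor, eos appended
print(d.string(ids))                       # -> "hello world" (eos is skipped by default)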
29
modules/voice_conversion/fairseq/data/encoders/__init__.py
Normal file
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib
import os

from fairseq import registry


build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
    "--tokenizer",
    default=None,
)


build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
    "--bpe",
    default=None,
)


# automatically import any Python files in the encoders/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.data.encoders." + module)
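A sketch (unverified against this vendored copy) of how the two registries created above are typically consumed; "gpt2" and "moses" are the registration keys defined by modules later in this commit, and the Namespace objects stand in for parsed CLI args.

from argparse import Namespace
from fairseq.data import encoders

# Unspecified config fields should fall back to each dataclass's defaults
# (for "gpt2" that means fetching the public GPT-2 vocab files).
bpe = encoders.build_bpe(Namespace(bpe="gpt2"))
tokenizer = encoders.build_tokenizer(Namespace(tokenizer="moses"))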
48
modules/voice_conversion/fairseq/data/encoders/byte_bpe.py
Normal file
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
    SPACE,
    SPACE_ESCAPE,
    byte_encode,
    smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass


@dataclass
class ByteBpeConfig(FairseqDataclass):
    sentencepiece_model_path: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )


@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
    def __init__(self, cfg):
        vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(vocab)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        byte_encoded = byte_encode(x)
        return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))

    @staticmethod
    def decode(x: str) -> str:
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
51
modules/voice_conversion/fairseq/data/encoders/byte_utils.py
Normal file
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re


WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
    list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
BYTE_TO_BCHAR = {
    b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}


def byte_encode(x: str) -> str:
    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])


def byte_decode(x: str) -> str:
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except ValueError:
        return ""


def smart_byte_decode(x: str) -> str:
    output = byte_decode(x)
    if output == "":
        # DP the best recovery (max valid chars) if it's broken
        n_bytes = len(x)
        f = [0 for _ in range(n_bytes + 1)]
        pt = [0 for _ in range(n_bytes + 1)]
        for i in range(1, n_bytes + 1):
            f[i], pt[i] = f[i - 1], i - 1
            for j in range(1, min(4, i) + 1):
                if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
                    f[i], pt[i] = f[i - j] + 1, i - j
        cur_pt = n_bytes
        while cur_pt > 0:
            if f[cur_pt] == f[pt[cur_pt]] + 1:
                output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
            cur_pt = pt[cur_pt]
    return output
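If appended to byte_utils.py, this sanity check exercises the helpers above: byte_encode maps every UTF-8 byte to one printable character, byte_decode inverts it exactly, and smart_byte_decode salvages what it can from a truncated encoding (the sample string is illustrative).

sample = "café 声"
assert byte_decode(byte_encode(sample)) == sample   # exact round trip

broken = byte_encode(sample)[:-1]     # drop one byte-character mid-codepoint
print(smart_byte_decode(broken))      # DP recovery prints "café " here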
34
modules/voice_conversion/fairseq/data/encoders/bytes.py
Normal file
@@ -0,0 +1,34 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
    SPACE,
    SPACE_ESCAPE,
    byte_encode,
    smart_byte_decode,
)


@register_bpe("bytes")
class Bytes(object):
    def __init__(self, *unused):
        pass

    @staticmethod
    def add_args(parser):
        pass

    @staticmethod
    def encode(x: str) -> str:
        encoded = byte_encode(x)
        escaped = encoded.replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(list(escaped))

    @staticmethod
    def decode(x: str) -> str:
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
30
modules/voice_conversion/fairseq/data/encoders/characters.py
Normal file
@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from fairseq.data.encoders import register_bpe


SPACE = chr(32)
SPACE_ESCAPE = chr(9601)


@register_bpe("characters")
class Characters(object):
    def __init__(self, *unused):
        pass

    @staticmethod
    def add_args(parser):
        pass

    @staticmethod
    def encode(x: str) -> str:
        escaped = x.replace(SPACE, SPACE_ESCAPE)
        return SPACE.join(list(escaped))

    @staticmethod
    def decode(x: str) -> str:
        return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
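A quick illustration of the character-level scheme above: spaces are escaped to U+2581 and every character then becomes its own space-separated token (the input string is illustrative).

enc = Characters.encode("hi there")
print(enc)                                  # "h i ▁ t h e r e"
assert Characters.decode(enc) == "hi there"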
36
modules/voice_conversion/fairseq/data/encoders/fastbpe.py
Normal file
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class fastBPEConfig(FairseqDataclass):
    bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})


@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=fastbpe")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            import fastBPE

            self.bpe = fastBPE.fastBPE(codes)
            self.bpe_symbol = "@@ "
        except ImportError:
            raise ImportError("Please install fastBPE with: pip install fastBPE")

    def encode(self, x: str) -> str:
        return self.bpe.apply([x])[0]

    def decode(self, x: str) -> str:
        return (x + " ").replace(self.bpe_symbol, "").rstrip()
45
modules/voice_conversion/fairseq/data/encoders/gpt2_bpe.py
Normal file
@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass

from .gpt2_bpe_utils import get_encoder


DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"


@dataclass
class GPT2BPEConfig(FairseqDataclass):
    gpt2_encoder_json: str = field(
        default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
    )
    gpt2_vocab_bpe: str = field(
        default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
    )


@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
    def __init__(self, cfg):
        encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
        vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
        self.bpe = get_encoder(encoder_json, vocab_bpe)

    def encode(self, x: str) -> str:
        return " ".join(map(str, self.bpe.encode(x)))

    def decode(self, x: str) -> str:
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(" ")
140
modules/voice_conversion/fairseq/data/encoders/gpt2_bpe_utils.py
Normal file
@@ -0,0 +1,140 @@
"""
Byte pair encoding utilities from GPT-2.

Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""

import json
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re

            self.re = re
        except ImportError:
            raise ImportError("Please install regex with: pip install regex")

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder.get(token, token) for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", errors=self.errors
        )
        return text


def get_encoder(encoder_json_path, vocab_bpe_path):
    with open(encoder_json_path, "r") as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, "r", encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
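A toy, self-contained rerun of the merge loop in Encoder.bpe above, with a hand-written two-entry rank table (hypothetical; real GPT-2 merges come from vocab.bpe): the lowest-ranked adjacent pair is merged repeatedly until no rankable pair remains.

ranks = {("h", "u"): 0, ("hu", "g"): 1}     # lower rank = merged earlier
word = ("h", "u", "g", "s")
while True:
    pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
    candidates = [p for p in pairs if p in ranks]
    if not candidates:
        break
    a, b = min(candidates, key=lambda p: ranks[p])
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == (a, b):
            merged.append(a + b)          # apply the merge
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
print(word)  # ('hug', 's')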
50
modules/voice_conversion/fairseq/data/encoders/hf_bert_bpe.py
Normal file
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class BertBPEConfig(FairseqDataclass):
    bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
    bpe_vocab_file: Optional[str] = field(
        default=None, metadata={"help": "bpe vocab file"}
    )


@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object):
    def __init__(self, cfg):
        try:
            from transformers import BertTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers with: pip install transformers"
            )

        if cfg.bpe_vocab_file:
            self.bert_tokenizer = BertTokenizer(
                cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
            )
        else:
            vocab_file_name = (
                "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
            )
            self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)

    def encode(self, x: str) -> str:
        return " ".join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.bert_tokenizer.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return not x.startswith("##")
50
modules/voice_conversion/fairseq/data/encoders/hf_byte_bpe.py
Normal file
@@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from fairseq import file_utils


@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
    bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
    bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
    bpe_add_prefix_space: bool = field(
        default=False, metadata={"help": "add prefix space before encoding"}
    )


@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
    def __init__(self, cfg):
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError(
                "Please install huggingface/tokenizers with: pip install tokenizers"
            )

        bpe_vocab = file_utils.cached_path(cfg.bpe_vocab)
        bpe_merges = file_utils.cached_path(cfg.bpe_merges)

        self.bpe = ByteLevelBPETokenizer(
            bpe_vocab,
            bpe_merges,
            add_prefix_space=cfg.bpe_add_prefix_space,
        )

    def encode(self, x: str) -> str:
        return " ".join(map(str, self.bpe.encode(x).ids))

    def decode(self, x: str) -> str:
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(" ")
49
modules/voice_conversion/fairseq/data/encoders/moses_tokenizer.py
Normal file
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    source_lang: str = field(default="en", metadata={"help": "source language"})
    target_lang: str = field(default="en", metadata={"help": "target language"})
    moses_no_dash_splits: bool = field(
        default=False, metadata={"help": "don't apply dash split rules"}
    )
    moses_no_escape: bool = field(
        default=False,
        metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
    )


@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    def __init__(self, cfg: MosesTokenizerConfig):
        self.cfg = cfg

        try:
            from sacremoses import MosesTokenizer, MosesDetokenizer

            self.tok = MosesTokenizer(cfg.source_lang)
            self.detok = MosesDetokenizer(cfg.target_lang)
        except ImportError:
            raise ImportError(
                "Please install Moses tokenizer with: pip install sacremoses"
            )

    def encode(self, x: str) -> str:
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
            return_str=True,
            escape=(not self.cfg.moses_no_escape),
        )

    def decode(self, x: str) -> str:
        return self.detok.detokenize(x.split())
24
modules/voice_conversion/fairseq/data/encoders/nltk_tokenizer.py
Normal file
@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("nltk", dataclass=FairseqDataclass)
class NLTKTokenizer(object):
    def __init__(self, *unused):
        try:
            from nltk.tokenize import word_tokenize

            self.word_tokenize = word_tokenize
        except ImportError:
            raise ImportError("Please install nltk with: pip install nltk")

    def encode(self, x: str) -> str:
        return " ".join(self.word_tokenize(x))

    def decode(self, x: str) -> str:
        return x
65
modules/voice_conversion/fairseq/data/encoders/sentencepiece_bpe.py
Normal file
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class SentencepieceConfig(FairseqDataclass):
    sentencepiece_model: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )
    sentencepiece_enable_sampling: bool = field(
        default=False, metadata={"help": "enable sampling"}
    )
    sentencepiece_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "smoothing parameter for unigram sampling, "
            "and merge probability for BPE-dropout"
        },
    )


@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object):
    def __init__(self, cfg):
        self.enable_sampling = cfg.sentencepiece_enable_sampling
        self.alpha = cfg.sentencepiece_alpha
        sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(sentencepiece_model)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        return " ".join(
            self.sp.Encode(
                x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
            )
        )

    def decode(self, x: str) -> str:
        return x.replace(" ", "").replace("\u2581", " ").strip()

    def is_beginning_of_word(self, x: str) -> bool:
        if x in ["<unk>", "<s>", "</s>", "<pad>"]:
            # special elements are always considered beginnings
            # HACK: this logic is already present in fairseq/tasks/masked_lm.py
            # but these special tokens are also contained in the sentencepiece
            # vocabulary which causes duplicate special tokens. This hack makes
            # sure that they are all taken into account.
            return True
        return x.startswith("\u2581")
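A hedged usage sketch for SentencepieceBPE above; "spm.model" is an illustrative path to a trained SentencePiece model, and the round trip should hold for ordinary text with sampling disabled.

cfg = SentencepieceConfig(sentencepiece_model="spm.model")  # illustrative path
bpe = SentencepieceBPE(cfg)
pieces = bpe.encode("voice conversion")     # e.g. "▁voice ▁conversion"
assert bpe.decode(pieces) == "voice conversion"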
21
modules/voice_conversion/fairseq/data/encoders/space_tokenizer.py
Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass


@register_tokenizer("space", dataclass=FairseqDataclass)
class SpaceTokenizer(object):
    def __init__(self, *unused):
        self.space_tok = re.compile(r"\s+")

    def encode(self, x: str) -> str:
        return self.space_tok.sub(" ", x)

    def decode(self, x: str) -> str:
        return x
54
modules/voice_conversion/fairseq/data/encoders/subword_nmt_bpe.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
    bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
    bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})


@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            bpe_parser = apply_bpe.create_parser()
            bpe_args = bpe_parser.parse_args(
                [
                    "--codes",
                    codes,
                    "--separator",
                    cfg.bpe_separator,
                ]
            )
            self.bpe = apply_bpe.BPE(
                bpe_args.codes,
                bpe_args.merges,
                bpe_args.separator,
                None,
                bpe_args.glossaries,
            )
            self.bpe_symbol = bpe_args.separator + " "
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt"
            )

    def encode(self, x: str) -> str:
        return self.bpe.process_line(x)

    def decode(self, x: str) -> str:
        return (x + " ").replace(self.bpe_symbol, "").rstrip()
30
modules/voice_conversion/fairseq/data/encoders/utils.py
Normal file
@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq.data import encoders


def get_whole_word_mask(args, dictionary):
    bpe = encoders.build_bpe(args)
    if bpe is not None:

        def is_beginning_of_word(i):
            if i < dictionary.nspecial:
                # special elements are always considered beginnings
                return True
            tok = dictionary[i]
            if tok.startswith("madeupword"):
                return True
            try:
                return bpe.is_beginning_of_word(tok)
            except ValueError:
                return True

        mask_whole_words = torch.ByteTensor(
            list(map(is_beginning_of_word, range(len(dictionary))))
        )
        return mask_whole_words
    return None
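Toy illustration of the word-start convention behind get_whole_word_mask above: with a GPT-2-style BPE, a token begins a word iff its decoded form starts with a space (the decoded tokens below are hypothetical).

decoded = [" voice", " con", "version"]     # "version" continues " con"
is_word_start = [int(t.startswith(" ")) for t in decoded]
print(is_word_start)                        # [1, 1, 0]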
205
modules/voice_conversion/fairseq/data/fairseq_dataset.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data
from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies on ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    For datasets that need to be read sequentially, usually because the data is
    being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
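

# A minimal, illustrative FairseqDataset implementation (a sketch, not used
# elsewhere in this module): it wraps a hypothetical list of 1-D LongTensors
# and fills in the batching contract that the base class leaves abstract.
class _ToySequenceDataset(FairseqDataset):
    def __init__(self, sequences):
        self.sequences = sequences
        # exposing a `sizes` ndarray lets filter_indices_by_size() take its fast path
        self.sizes = np.array([len(s) for s in sequences], dtype=np.int64)

    def __getitem__(self, index):
        return self.sequences[index]

    def __len__(self):
        return len(self.sequences)

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def collater(self, samples):
        # naive zero right-padding; fairseq datasets normally use
        # data_utils.collate_tokens() for this
        max_len = max(len(s) for s in samples)
        batch = torch.zeros(len(samples), max_len, dtype=torch.int64)
        for i, s in enumerate(samples):
            batch[i, : len(s)] = s
        return batch


# e.g.:
#   ds = _ToySequenceDataset(seqs)
#   batches = ds.batch_by_size(ds.ordered_indices(), max_tokens=4096)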
107
modules/voice_conversion/fairseq/data/fasta_dataset.py
Normal file
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import threading
from pathlib import Path

import numpy as np
import torch


def fasta_file_path(prefix_path):
    return prefix_path + ".fasta"


class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        f = self._get_file()
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        line = f.readline()
        seq = ""
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        # Use grep and awk to get 100M/s on local SSD.
        # Should process your enormous 100G fasta in ~10 min single core...
        path = fasta_file_path(path)
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ")
        sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ")
        return bytes_np, sizes_np

    def __setstate__(self, state):
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f

    @staticmethod
    def exists(path):
        return os.path.exists(fasta_file_path(path))


class EncodedFastaDataset(FastaDataset):
    """
    The FastaDataset returns raw sequences - this allows us to return
    indices with a dictionary instead.
    """

    def __init__(self, path, dictionary):
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        desc, seq = super().__getitem__(idx)
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
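

# Illustrative sketch: reading sequences back with FastaDataset. The prefix
# "/tmp/proteins" is hypothetical; fasta_file_path() appends ".fasta", so this
# expects "/tmp/proteins.fasta" on disk, and _build_index() shells out to
# cat/tqdm/grep/awk, so a Unix environment with tqdm on PATH is assumed.
def _print_first_sequences(prefix="/tmp/proteins", limit=3):
    if FastaDataset.exists(prefix):
        ds = FastaDataset(prefix, cache_indices=True)
        for i in range(min(limit, len(ds))):
            desc, seq = ds[i]
            print(desc, len(seq))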
21
modules/voice_conversion/fairseq/data/huffman/__init__.py
Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
    HuffmanMMapIndex,
    HuffmanMMapIndexedDataset,
    HuffmanMMapIndexedDatasetBuilder,
    vocab_file_path,
)

__all__ = [
    "HuffmanCoder",
    "HuffmanCodeBuilder",
    "HuffmanMMapIndexedDatasetBuilder",
    "HuffmanMMapIndexedDataset",
    "HuffmanMMapIndex",
    "vocab_file_path",
]
267
modules/voice_conversion/fairseq/data/huffman/huffman_coder.py
Normal file
@@ -0,0 +1,267 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
import typing as tp
from collections import Counter, deque
from dataclasses import dataclass

from bitarray import bitarray, util
from fairseq.data import Dictionary

# basically we have to write to addressable bytes for the memory mapped
# dataset loader. Sentences that get encoded to a length that is not a
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
BLOCKSIZE = 8


class HuffmanCoder:
    def __init__(
        self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
    ):
        self.root = root
        self.table = root.code_table()
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos

    def _pad(self, a: bitarray) -> bitarray:
        """
        bitpadding, 1 then 0.

        If the array is already a multiple of blocksize, we add a full block.
        """
        pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
        padding = bitarray("1" + "0" * pad_len)
        return a + padding

    def _unpad(self, a: bitarray) -> bitarray:
        """
        remove the bitpadding.

        There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
        """
        # count the 0 padding at the end until we find the first 1
        # we want to remove the one too
        remove_cnt = util.rindex(a, 1)
        return a[:remove_cnt]

    def encode(self, iter: tp.List[str]) -> bytes:
        """
        encode a list of tokens and return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
        """
        a = bitarray()
        for token in iter:
            code = self.get_code(token)
            if code is None:
                if self.unk_word is None:
                    raise Exception(f"unknown token {token} cannot be encoded.")
                else:
                    token = self.unk_word
            a = a + self.get_code(token)
        return self._pad(a).tobytes()

    def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
        """
        take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id
        """
        a = bitarray()
        a.frombytes(bits)
        return self.root.decode(self._unpad(a))

    def get_code(self, symbol: str) -> tp.Optional[bitarray]:
        node = self.get_node(symbol)
        return None if node is None else node.code

    def get_node(self, symbol: str) -> "HuffmanNode":
        return self.table.get(symbol)

    @classmethod
    def from_file(
        cls,
        filename: str,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> "HuffmanCoder":
        builder = HuffmanCodeBuilder.from_file(filename)
        return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)

    def to_file(self, filename, sep="\t"):
        nodes = list(self.table.values())
        nodes.sort(key=lambda n: n.id)
        with open(filename, "w", encoding="utf-8") as output:
            for n in nodes:
                output.write(f"{n.symbol}{sep}{n.count}\n")

    def __iter__(self):
        for n in self.table.values():
            yield n

    def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
        builder = HuffmanCodeBuilder()
        for n in self:
            builder.increment(n.symbol, n.count)
        for n in other_coder:
            builder.increment(n.symbol, n.count)
        return builder.build_code()

    def __eq__(self, other: "HuffmanCoder") -> bool:
        return self.table == other.table

    def __len__(self) -> int:
        return len(self.table)

    def __contains__(self, sym: str) -> bool:
        return sym in self.table

    def to_dictionary(self) -> Dictionary:
        # note: the word attributes are named *_word on this class
        dictionary = Dictionary(
            bos=self.bos_word, unk=self.unk_word, pad=self.pad_word, eos=self.eos_word
        )
        for n in self:
            dictionary.add_symbol(n.symbol, n=n.count)
        dictionary.finalize()
        return dictionary


@dataclass
class HuffmanNode:
    """
    a node in a Huffman tree
    """

    id: int
    count: int
    symbol: tp.Optional[str] = None
    left: tp.Optional["HuffmanNode"] = None
    right: tp.Optional["HuffmanNode"] = None
    code: tp.Optional[bitarray] = None

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None

    def code_table(
        self, prefix: tp.Optional[bitarray] = None
    ) -> tp.Dict[str, "HuffmanNode"]:
        defaulted_prefix = prefix if prefix is not None else bitarray()
        if self.is_leaf():
            self.code = (
                defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
            )  # leaf could be the root if there is only one symbol
            return {self.symbol: self}

        codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
        codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
        return {**codes_left, **codes_right}

    def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
        current_node = self
        for bit in bits:
            if bit == 0:  # go right
                current_node = current_node.right
            else:  # go left
                current_node = current_node.left
            if current_node is None:
                # we shouldn't be on a leaf here
                raise Exception("fell off a leaf")
            if current_node.is_leaf():
                yield current_node
                current_node = self
        if current_node != self:
            raise Exception("couldn't decode all the bits")


class HuffmanCodeBuilder:
    """
    build a dictionary with occurrence count and then build the Huffman code for it.
    """

    def __init__(self):
        self.symbols = Counter()

    def add_symbols(self, *syms) -> None:
        self.symbols.update(syms)

    def increment(self, symbol: str, cnt: int) -> None:
        self.symbols[symbol] += cnt

    @classmethod
    def from_file(cls, filename):
        c = cls()
        with open(filename, "r", encoding="utf-8") as input:
            for line in input:
                split = re.split(r"[\s]+", line)
                c.increment(split[0], int(split[1]))
        return c

    def to_file(self, filename, sep="\t"):
        with open(filename, "w", encoding="utf-8") as output:
            for (tok, cnt) in self.symbols.most_common():
                output.write(f"{tok}{sep}{cnt}\n")

    def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
        if len(q1) == 0:
            return q2.pop()

        if len(q2) == 0:
            return q1.pop()

        if q1[-1].count < q2[-1].count:
            return q1.pop()

        return q2.pop()

    def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
        new_c = self.symbols + c.symbols
        new_b = HuffmanCodeBuilder()
        new_b.symbols = new_c
        return new_b

    def build_code(
        self,
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
    ) -> HuffmanCoder:
        assert len(self.symbols) > 0, "cannot build code from empty list of symbols"

        if self.symbols[bos] == 0:
            self.add_symbols(bos)
        if self.symbols[pad] == 0:
            self.add_symbols(pad)
        if self.symbols[eos] == 0:
            self.add_symbols(eos)
        if self.symbols[unk] == 0:
            self.add_symbols(unk)

        node_id = 0
        leaves_queue = deque(
            [
                HuffmanNode(symbol=symbol, count=count, id=idx)
                for idx, (symbol, count) in enumerate(self.symbols.most_common())
            ]
        )  # left are the most common, right are the least common

        if len(leaves_queue) == 1:
            root = leaves_queue.pop()
            root.id = 0
            return HuffmanCoder(root)

        nodes_queue = deque()

        while len(leaves_queue) > 0 or len(nodes_queue) != 1:
            # get the lowest two nodes at the head of each queue
            node1 = self._smallest(leaves_queue, nodes_queue)
            node2 = self._smallest(leaves_queue, nodes_queue)

            # add new node
            nodes_queue.appendleft(
                HuffmanNode(
                    count=node1.count + node2.count, left=node1, right=node2, id=node_id
                )
            )
            node_id += 1

        # we are left with the root
        return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
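

# Illustrative round-trip sketch: build a code from raw symbol counts, then
# encode and decode a token sequence. Tokens the builder never saw fall back
# to the <unk> code inside encode().
def _huffman_roundtrip_example():
    builder = HuffmanCodeBuilder()
    builder.add_symbols("a", "a", "a", "b", "b", "c")
    coder = builder.build_code()
    packed = coder.encode(["a", "b", "a", "c"])  # bitpadded bytes
    decoded = [node.symbol for node in coder.decode(packed)]
    assert decoded == ["a", "b", "a", "c"]
    return packed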
287
modules/voice_conversion/fairseq/data/huffman/huffman_mmap_indexed_dataset.py
Normal file
@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import mmap
import os
import shutil
import struct
import typing as tp
from functools import lru_cache

import numpy as np
import torch
from fairseq.data import indexed_dataset
from fairseq.data.huffman import HuffmanCoder
from fairseq.file_io import PathManager


class HuffmanMMapIndex:
    """
    keep an index of the offsets in the huffman binary file.
    First a header, then the list of sizes (num tokens) for each instance and finally
    the addresses of each instance.
    """

    _HDR_MAGIC = b"HUFFIDX\x00\x00"
    _VERSION = 1

    @classmethod
    def writer(cls, path: str, data_len: int):
        class _Writer:
            def __enter__(self):
                self._file = open(path, "wb")

                # write header (magic + version)
                self._file.write(cls._HDR_MAGIC)
                self._file.write(struct.pack("<Q", cls._VERSION))
                self._file.write(struct.pack("<Q", data_len))

                return self

            def write(self, sizes, pointers):
                # add number of items in the index to the header
                self._file.write(struct.pack("<Q", len(sizes)))

                # write sizes
                sizes = np.array(sizes, dtype=np.int32)
                self._file.write(sizes.tobytes(order="C"))
                del sizes

                # write address pointers
                pointers = np.array(pointers, dtype=np.int64)
                self._file.write(pointers.tobytes(order="C"))
                del pointers

            def __exit__(self, exc_type, exc_val, exc_tb):
                self._file.close()

        return _Writer()

    def __init__(self, path):
        with open(path, "rb") as stream:
            # read headers
            magic_test = stream.read(9)
            assert self._HDR_MAGIC == magic_test, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            (version,) = struct.unpack("<Q", stream.read(8))
            assert (
                self._VERSION == version
            ), f"Unexpected file version {version} != code version {self._VERSION}"

            # read length of data file
            (self._data_len,) = struct.unpack("<Q", stream.read(8))
            # read number of items in data file/index
            (self._len,) = struct.unpack("<Q", stream.read(8))
            offset = stream.tell()

        indexed_dataset._warmup_mmap_file(path)

        self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
        self._sizes = np.frombuffer(
            self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
        )
        self._pointers = np.frombuffer(
            self._bin_buffer,
            dtype=np.int64,
            count=self._len,
            offset=offset + self._sizes.nbytes,
        )

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap

    def __iter__(self):
        for i in range(self._len):
            yield self[i]

    @property
    def data_len(self):
        return self._data_len

    @property
    def sizes(self):
        return self._sizes

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        return self._pointers[i], self._sizes[i]

    def __len__(self):
        return self._len


def vocab_file_path(prefix_path):
    return prefix_path + ".vocab"


class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
    """
    an indexed dataset that uses mmap and memoryview to access data from disk
    that was compressed with a HuffmanCoder.
    """

    def __init__(self, prefix_path):
        super().__init__()

        self._prefix_path = None
        self._index = None
        self._bin_buffer = None
        self._coder = None
        self._file = None

        self._bin_buffer_mmap = None

        self._do_init(prefix_path)

    def __getstate__(self):
        return self._prefix_path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, prefix_path):
        self._prefix_path = prefix_path
        self._index = HuffmanMMapIndex(
            indexed_dataset.index_file_path(self._prefix_path)
        )
        self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))

        indexed_dataset._warmup_mmap_file(
            indexed_dataset.data_file_path(self._prefix_path)
        )
        self._file = os.open(
            indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
        )
        self._bin_buffer_mmap = mmap.mmap(
            self._file,
            self._index.data_len,
            access=mmap.ACCESS_READ,
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        del self._bin_buffer
        if self._file:
            os.close(self._file)
        del self._index

    def __len__(self):
        return len(self._index)

    def _decode(self, i):
        ptr, _ = self._index[i]
        if i == 0:
            raw_bytes = self._bin_buffer[:ptr]
        else:
            (prev_ptr, _) = self._index[i - 1]
            raw_bytes = self._bin_buffer[prev_ptr:ptr]

        return self._coder.decode(raw_bytes.tobytes())

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        nodes = self._decode(i)
        return torch.tensor([n.id for n in nodes], dtype=torch.int64)

    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]

    def get_symbols(self, i):
        nodes = self._decode(i)
        for n in nodes:
            yield n.symbol

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @property
    def coder(self):
        return self._coder

    @staticmethod
    def exists(prefix_path):
        return (
            PathManager.exists(indexed_dataset.index_file_path(prefix_path))
            and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
            and PathManager.exists(vocab_file_path(prefix_path))
        )


class HuffmanMMapIndexedDatasetBuilder:
    """
    Helper to build a memory mapped dataset with a Huffman encoder.
    You can either open/close this manually or use it as a ContextManager.
    Provide your own coder, it will then be stored alongside the dataset.
    The builder will first write the vocab file, then open the binary file so you can stream
    into it, finally the index will be written when the builder is closed (your index should fit in memory).
    """

    def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
        self._path_prefix = path_prefix
        self._coder = coder
        self._sizes = []
        self._ptrs = []
        self._data_len = 0

    def open(self):
        self._coder.to_file(vocab_file_path(self._path_prefix))
        self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")

    def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
        self.open()
        return self

    def add_item(self, tokens: tp.List[str]) -> None:
        """
        add a list of tokens to the dataset, they will be compressed with the
        provided coder before being written to file.
        """
        encoded = self._coder.encode(tokens)
        code_len = len(encoded)
        last_ptr = 0
        if len(self._ptrs) > 0:
            last_ptr = self._ptrs[-1]
        self._sizes.append(len(tokens))
        self._ptrs.append(last_ptr + code_len)
        self._data_len += code_len
        self._data_file.write(encoded)

    def append(self, other_dataset_path_prefix: str) -> None:
        """
        append an existing dataset.
        Beware, if it wasn't built with the same coder, you are in trouble.
        """
        other_index = HuffmanMMapIndex(
            indexed_dataset.index_file_path(other_dataset_path_prefix)
        )
        for (ptr, size) in other_index:
            self._ptrs.append(ptr + self._data_len)
            self._sizes.append(size)

        # Concatenate data
        with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

        self._data_len += other_index.data_len

    def close(self):
        self._data_file.close()
        with HuffmanMMapIndex.writer(
            indexed_dataset.index_file_path(self._path_prefix), self._data_len
        ) as index:
            index.write(self._sizes, self._ptrs)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
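

# Illustrative sketch of the builder contract documented above: the vocab is
# written on open(), encoded items are streamed into the .bin file, and the
# index is emitted on close(). The prefix "/tmp/hufftest" is hypothetical.
def _build_tiny_huffman_dataset(coder: HuffmanCoder, prefix="/tmp/hufftest"):
    with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as builder:
        builder.add_item(["hello", "world"])
        builder.add_item(["hello"])
    return HuffmanMMapIndexedDataset(prefix)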
19
modules/voice_conversion/fairseq/data/id_dataset.py
Normal file
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class IdDataset(FairseqDataset):
    def __getitem__(self, index):
        return index

    def __len__(self):
        return 0

    def collater(self, samples):
        return torch.tensor(samples)
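

# Illustrative sketch: IdDataset simply echoes indices, and its collater turns
# a list of indices into a tensor, so a composed batch can carry example ids.
def _id_dataset_example():
    ids = IdDataset()
    return ids.collater([ids[3], ids[7]])  # tensor([3, 7])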
587
modules/voice_conversion/fairseq/data/indexed_dataset.py
Normal file
@@ -0,0 +1,587 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import shutil
import struct
from functools import lru_cache

import numpy as np
import torch
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex

from . import FairseqDataset

from typing import Union


def best_fitting_int_dtype(
    max_int_to_represent,
) -> Union[np.uint16, np.uint32, np.int64]:

    if max_int_to_represent is None:
        return np.uint32  # Safe guess
    elif max_int_to_represent < 65500:
        return np.uint16
    elif max_int_to_represent < 4294967295:
        return np.uint32
    else:
        return np.int64
        # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly
        # https://github.com/numpy/numpy/issues/5745


def get_available_dataset_impl():
    return list(map(str, DATASET_IMPL_CHOICES))


def infer_dataset_impl(path):
    if IndexedRawTextDataset.exists(path):
        return "raw"
    elif IndexedDataset.exists(path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return "cached"
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return "mmap"
            elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]:
                return "huffman"
            else:
                return None
    elif FastaDataset.exists(path):
        return "fasta"
    else:
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=best_fitting_int_dtype(vocab_size)
        )
    elif impl == "fasta":
        raise NotImplementedError
    elif impl == "huffman":
        raise ValueError(
            "Use HuffmanCodeBuilder directly as it has a different interface."
        )
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
    if impl == "raw" and IndexedRawTextDataset.exists(path):
        assert dictionary is not None
        return IndexedRawTextDataset(path, dictionary)
    elif impl == "lazy" and IndexedDataset.exists(path):
        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path)
    elif impl == "fasta" and FastaDataset.exists(path):
        from fairseq.data.fasta_dataset import EncodedFastaDataset

        return EncodedFastaDataset(path, dictionary)
    elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path):
        return HuffmanMMapIndexedDataset(path)
    return None


def dataset_exists(path, impl):
    if impl == "raw":
        return IndexedRawTextDataset.exists(path)
    elif impl == "mmap":
        return MMapIndexedDataset.exists(path)
    elif impl == "huffman":
        return HuffmanMMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


_code_to_dtype = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float64,
    7: np.double,
    8: np.uint16,
    9: np.uint32,
    10: np.uint64,
}


def _dtype_header_code(dtype) -> int:
    for k in _code_to_dtype.keys():
        if _code_to_dtype[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset"""

    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.path = path
        self.fix_lua_indexing = fix_lua_indexing
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = _code_to_dtype[code]
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    @lru_cache(maxsize=8)
    def __getitem__(self, i) -> torch.Tensor:
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):
    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        ptx = self.cache_index[i]
        np.copyto(a, self.cache[ptx : ptx + a.size])
        item = torch.from_numpy(a).long()
        if self.fix_lua_indexing:
            item -= 1  # subtract 1 for 0-based indexing
        return item


class IndexedRawTextDataset(FairseqDataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""

    def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
        self.tokens_list = []
        self.lines = []
        self.sizes = []
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.read_data(path, dictionary)
        self.size = len(self.tokens_list)

    def read_data(self, path, dictionary):
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                self.lines.append(line.strip("\n"))
                tokens = dictionary.encode_line(
                    line,
                    add_if_not_exist=False,
                    append_eos=self.append_eos,
                    reverse_order=self.reverse_order,
                ).long()
                self.tokens_list.append(tokens)
                self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError("index out of range")

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]

    def get_original_text(self, i):
        self.check_index(i)
        return self.lines[i]

    def __del__(self):
        pass

    def __len__(self):
        return self.size

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return PathManager.exists(path)


class IndexedDatasetBuilder:
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float64: 4,
        np.double: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]

    def add_item(self, tensor):
        # +1 for Lua compatibility
        bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), "rb") as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, "wb")
        index.write(b"TNTIDX\x00\x00")
        index.write(struct.pack("<Q", 1))
        index.write(
            struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
        )
        index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        index.close()


def _warmup_mmap_file(path):
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index:
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer:
                def __enter__(self):
                    self._file = open(path, "wb")

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", _dtype_header_code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack("<Q", len(sizes)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = _code_to_dtype[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path):
        self._path = path
        self._index = self.Index(index_file_path(self._path))

        _warmup_mmap_file(data_file_path(self._path))
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    @lru_cache(maxsize=8)
    def __getitem__(self, i):
        ptr, size = self._index[i]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )
        if self._index.dtype != np.int64:
            np_array = np_array.astype(np.int64)

        return torch.from_numpy(np_array)

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return PathManager.exists(index_file_path(path)) and PathManager.exists(
            data_file_path(path)
        )


def get_indexed_dataset_to_local(path) -> str:
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), (
        "PathManager.get_local_path does not return files with expected patterns: "
        f"{local_index_path} and {local_data_path}"
    )

    local_path = local_data_path[:-4]  # stripping suffix ".bin"
    assert local_path == local_index_path[:-4]  # stripping suffix ".idx"
    return local_path


class MMapIndexedDatasetBuilder:
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes)
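

# Illustrative round-trip sketch: write two tensors with the "mmap" builder,
# finalize the index, and read them back. The prefix "/tmp/mmaptest" is
# hypothetical; make_builder() picks a compact integer dtype from vocab_size.
def _mmap_roundtrip_example(prefix="/tmp/mmaptest", vocab_size=1000):
    builder = make_builder(data_file_path(prefix), impl="mmap", vocab_size=vocab_size)
    builder.add_item(torch.tensor([1, 2, 3]))
    builder.add_item(torch.tensor([4, 5]))
    builder.finalize(index_file_path(prefix))

    ds = make_dataset(prefix, impl="mmap")
    assert ds[1].tolist() == [4, 5]
    return ds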
883
modules/voice_conversion/fairseq/data/iterators.py
Normal file
@@ -0,0 +1,883 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import math
import operator
import os
import queue
import time
from threading import Thread
from typing import Iterator, List

import numpy as np
import torch
from fairseq.data import data_utils


logger = logging.getLogger(__name__)

# Object used by _background_consumer to signal the source is exhausted
# to the main thread.
_sentinel = object()


class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    Args:
        iterable (iterable): iterable to wrap
        start (int): starting iteration count. Note that this doesn't
            actually advance the iterator.
        total (int): override the iterator length returned by ``__len__``.
            This can be used to truncate *iterator*.

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, start=None, total=None):
        self._itr = iter(iterable)
        self.n = start or getattr(iterable, "n", 0)
        self.total = total if total is not None else self.n + len(iterable)

    def __len__(self):
        return self.total

    def __iter__(self):
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        try:
            x = next(self._itr)
        except StopIteration:
            raise IndexError(
                f"Iterator expected to have length {self.total}, "
                f"but exhausted at position {self.n}."
            )
        self.n += 1
        return x

    def has_next(self):
        """Whether the iterator has more elements (i.e. has not been exhausted)."""
        return self.n < self.total

    def skip(self, n):
        """Fast-forward the iterator by skipping n elements."""
        for _ in range(n):
            next(self)
        return self

    def take(self, n):
        """Truncate the iterator to n elements at most."""
        self.total = min(self.total, n)
        # Propagate this change to the underlying iterator
        if hasattr(self._itr, "take"):
            self._itr.take(max(n - self.n, 0))
        return self
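

# Illustrative sketch: CountingIterator tracks how much of an epoch has been
# consumed so iteration can be resumed; skip() fast-forwards, take() truncates.
def _counting_iterator_example():
    itr = CountingIterator(range(10))
    first = next(itr)  # 0; itr.n is now 1
    itr.skip(2).take(5)  # consume two more elements, cap the total at 5
    return first, list(itr)  # (0, [3, 4])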


class EpochBatchIterating(object):
    def __len__(self) -> int:
        raise NotImplementedError

    @property
    def next_epoch_idx(self):
        raise NotImplementedError

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus (bool, optional): ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
            set_dataset_epoch (bool, optional): update the wrapped Dataset with
                the new epoch number (default: True).
        """
        raise NotImplementedError

    def end_of_epoch(self) -> bool:
        """Returns whether the most recent epoch iterator has been exhausted"""
        raise NotImplementedError

    @property
    def iterations_in_epoch(self) -> int:
        """The number of consumed batches in the current epoch."""
        raise NotImplementedError

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        raise NotImplementedError

    @property
    def first_batch(self):
        return "DUMMY"


class StreamingEpochBatchIterator(EpochBatchIterating):
    """A streaming-style iterator over a :class:`torch.utils.data.IterableDataset`.

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        max_sentences: batch size
        collate_fn (callable): merges a list of samples to form a mini-batch
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        buffer_size (int, optional): the number of batches to keep ready in the
            queue. Helps speeding up dataloading. When buffer_size is zero, the
            default torch.utils.data.DataLoader preloading is used.
        timeout (int, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative (default: ``0``).
    """

    def __init__(
        self,
        dataset,
        max_sentences=1,
        collate_fn=None,
        epoch=1,
        num_workers=0,
        buffer_size=0,
        timeout=0,
        persistent_workers=False,
    ):
        assert isinstance(dataset, torch.utils.data.IterableDataset)
        self.dataset = dataset
        self.max_sentences = max_sentences
        self.collate_fn = collate_fn
        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.num_workers = num_workers
        # This upper limit here is to prevent people from abusing this feature
        # in a shared computing environment.
        self.buffer_size = min(buffer_size, 20)
        self.timeout = timeout
        self.persistent_workers = persistent_workers

        self._current_epoch_iterator = None

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        if self._current_epoch_iterator is not None and self.end_of_epoch():
            return self.epoch + 1
        else:
            return self.epoch

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        self.epoch = self.next_epoch_idx
        if set_dataset_epoch and hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)
        self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle)
        return self._current_epoch_iterator

    def end_of_epoch(self) -> bool:
        return not self._current_epoch_iterator.has_next()

    @property
    def iterations_in_epoch(self) -> int:
        if self._current_epoch_iterator is not None:
            return self._current_epoch_iterator.n
        return 0

    def state_dict(self):
        return {
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]

    def _get_iterator_for_epoch(self, epoch, shuffle, offset=0):
        if self.num_workers > 0:
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        # Create data loader
        worker_init_fn = getattr(self.dataset, "worker_init_fn", None)
        itr = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.max_sentences,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            timeout=self.timeout,
            worker_init_fn=worker_init_fn,
            pin_memory=True,
            persistent_workers=self.persistent_workers,
        )

        # Wrap with a BufferedIterator if needed
        if self.buffer_size > 0:
            itr = BufferedIterator(self.buffer_size, itr)

        # Wrap with CountingIterator
        itr = CountingIterator(itr, start=offset)

        return itr
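

# Illustrative sketch: driving StreamingEpochBatchIterator over a toy
# IterableDataset (with __len__, which CountingIterator needs to size the
# epoch). One next_epoch_itr() call per epoch; state_dict() only records the
# epoch number, since a stream cannot be indexed into.
class _ToyStream(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(range(8))

    def __len__(self):
        return 8


def _streaming_epoch_example():
    epoch_iter = StreamingEpochBatchIterator(_ToyStream(), max_sentences=4)
    batches = list(epoch_iter.next_epoch_itr())
    return batches, epoch_iter.state_dict()  # state is {"epoch": 1}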


class FrozenBatchSampler:
    def __init__(
        self,
        ordered_batches,
        epoch,
        fix_batches_to_gpus,
        shuffle,
        initial_offset,
    ):
        self.ordered_batches = ordered_batches
        self.fix_batches_to_gpus = fix_batches_to_gpus
        self.shuffle = shuffle
        self.make_batches_for_epoch(epoch, initial_offset)

    def make_batches_for_epoch(self, epoch, offset=0):
        self.batches = self.ordered_batches(
            epoch, self.fix_batches_to_gpus, self.shuffle
        )
        if offset > 0:
            self.batches = self.batches[offset:]

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)


class EpochBatchIterator(EpochBatchIterating):
    """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.

    Compared to :class:`torch.utils.data.DataLoader`, this iterator:

    - can be reused across multiple epochs with the :func:`next_epoch_itr`
      method (optionally shuffled between epochs)
    - can be serialized/deserialized with the :func:`state_dict` and
      :func:`load_state_dict` methods
    - supports sharding with the *num_shards* and *shard_id* arguments

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of
            indices, or a callable to create such an iterator (~torch.utils.data.Sampler).
            A callable batch_sampler will be called for each epoch to enable per epoch dynamic
            batch iterators defined by this callable batch_sampler.
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        buffer_size (int, optional): the number of batches to keep ready in the
            queue. Helps speeding up dataloading. When buffer_size is zero, the
            default torch.utils.data.DataLoader preloading is used.
        timeout (int, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative (default: ``0``).
        disable_shuffling (bool, optional): force disable shuffling
            (default: ``False``).
        skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch
            for the sake of training stability, as the last batch is usually smaller than
            local_batch_size * distributed_world_size (default: ``False``).
        grouped_shuffling (bool, optional): enable shuffling batches in groups
            of num_shards. Ensures that each GPU receives similar length sequences when
            batches are sorted by length.
    """

    def __init__(
        self,
        dataset,
        collate_fn,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        buffer_size=0,
        timeout=0,
        disable_shuffling=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        reuse_dataloader=False,
        persistent_workers=False,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.batch_sampler = batch_sampler
        self._frozen_batches = (
            tuple(batch_sampler) if not callable(batch_sampler) else None
        )
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.num_workers = num_workers
        # This upper limit here is to prevent people from abusing this feature
        # in a shared computing environment.
        self.buffer_size = min(buffer_size, 20)
        self.timeout = timeout
        self.disable_shuffling = disable_shuffling
        self.skip_remainder_batch = skip_remainder_batch
        self.grouped_shuffling = grouped_shuffling

        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = not disable_shuffling
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = getattr(dataset, "supports_prefetch", False)

        self.dataloader = None
        self.reuse_dataloader = reuse_dataloader
        self.persistent_workers = persistent_workers

    @property
    def frozen_batches(self):
        if self._frozen_batches is None:
            self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch))
        return self._frozen_batches

    @property
    def first_batch(self):
        if len(self.frozen_batches) == 0:
            raise Exception(
                "The dataset is empty. This could indicate "
                "that all elements in the dataset have been skipped. "
                "Try increasing the max number of allowed tokens or using "
                "a larger dataset."
            )

        if getattr(self.dataset, "supports_fetch_outside_dataloader", True):
            return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]])
        else:
            return "DUMMY"

    def __len__(self):
        return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))

    @property
    def n(self):
        return self.iterations_in_epoch

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        if self._next_epoch_itr is not None:
            return self.epoch
        elif self._cur_epoch_itr is not None and self.end_of_epoch():
            return self.epoch + 1
        else:
            return self.epoch

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator (default: True).
            fix_batches_to_gpus (bool, optional): ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching (default: False).
            set_dataset_epoch (bool, optional): update the wrapped Dataset with
                the new epoch number (default: True).
        """
        if self.disable_shuffling:
            shuffle = False
        prev_epoch = self.epoch
        self.epoch = self.next_epoch_idx
        if set_dataset_epoch and hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(self.epoch)
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            if callable(self.batch_sampler) and prev_epoch != self.epoch:
                # reset _frozen_batches to refresh the next epoch
                self._frozen_batches = None
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch,
                shuffle,
                fix_batches_to_gpus=fix_batches_to_gpus,
|
||||
)
|
||||
self.shuffle = shuffle
|
||||
return self._cur_epoch_itr
|
||||
|
||||
def end_of_epoch(self) -> bool:
|
||||
"""Returns whether the most recent epoch iterator has been exhausted"""
|
||||
return not self._cur_epoch_itr.has_next()
|
||||
|
||||
@property
|
||||
def iterations_in_epoch(self):
|
||||
"""The number of consumed batches in the current epoch."""
|
||||
if self._cur_epoch_itr is not None:
|
||||
return self._cur_epoch_itr.n
|
||||
elif self._next_epoch_itr is not None:
|
||||
return self._next_epoch_itr.n
|
||||
return 0
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns a dictionary containing a whole state of the iterator."""
|
||||
if self.end_of_epoch():
|
||||
epoch = self.epoch + 1
|
||||
iter_in_epoch = 0
|
||||
else:
|
||||
epoch = self.epoch
|
||||
iter_in_epoch = self.iterations_in_epoch
|
||||
return {
|
||||
"version": 2,
|
||||
"epoch": epoch,
|
||||
"iterations_in_epoch": iter_in_epoch,
|
||||
"shuffle": self.shuffle,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Copies the state of the iterator from the given *state_dict*."""
|
||||
self.epoch = state_dict["epoch"]
|
||||
itr_pos = state_dict.get("iterations_in_epoch", 0)
|
||||
version = state_dict.get("version", 1)
|
||||
if itr_pos > 0:
|
||||
# fast-forward epoch iterator
|
||||
self._next_epoch_itr = self._get_iterator_for_epoch(
|
||||
self.epoch,
|
||||
shuffle=state_dict.get("shuffle", True),
|
||||
offset=itr_pos,
|
||||
)
|
||||
if self._next_epoch_itr is None:
|
||||
if version == 1:
|
||||
# legacy behavior: we finished the epoch, increment epoch counter
|
||||
self.epoch += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot resume training due to dataloader mismatch, please "
|
||||
"report this to the fairseq developers. You can relaunch "
|
||||
"training with `--reset-dataloader` and it should work."
|
||||
)
|
||||
else:
|
||||
self._next_epoch_itr = None
|
||||
|
||||
def _get_iterator_for_epoch(
|
||||
self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
|
||||
):
|
||||
if self.reuse_dataloader and self.dataloader is not None:
|
||||
self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset)
|
||||
itr = self.dataloader
|
||||
else:
|
||||
self.epoch_batch_sampler = FrozenBatchSampler(
|
||||
self.ordered_batches,
|
||||
epoch,
|
||||
fix_batches_to_gpus,
|
||||
shuffle,
|
||||
initial_offset=offset,
|
||||
)
|
||||
|
||||
if offset > 0 and len(self.epoch_batch_sampler) == 0:
|
||||
return None
|
||||
|
||||
if self.num_workers > 0:
|
||||
os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
|
||||
|
||||
# Create data loader
|
||||
itr = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_sampler=self.epoch_batch_sampler,
|
||||
num_workers=self.num_workers,
|
||||
timeout=self.timeout,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
)
|
||||
|
||||
if self.reuse_dataloader:
|
||||
self.dataloader = itr
|
||||
|
||||
# Wrap with a BufferedIterator if needed
|
||||
if self.buffer_size > 0:
|
||||
itr = BufferedIterator(self.buffer_size, itr)
|
||||
|
||||
# Wrap with CountingIterator
|
||||
itr = CountingIterator(itr, start=offset)
|
||||
|
||||
if self.skip_remainder_batch:
|
||||
# TODO: Below is a lazy implementation which discard the final batch regardless
|
||||
# of whether it is a full batch or not.
|
||||
|
||||
total_num_itrs = len(self.epoch_batch_sampler) - 1
|
||||
itr.take(total_num_itrs)
|
||||
logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")
|
||||
|
||||
return itr
|
||||
|
||||
def ordered_batches(self, epoch, fix_batches_to_gpus, shuffle):
|
||||
def shuffle_batches(batches, seed):
|
||||
with data_utils.numpy_seed(seed):
|
||||
|
||||
if self.grouped_shuffling:
|
||||
grouped_batches = [
|
||||
batches[(i * self.num_shards) : ((i + 1) * self.num_shards)]
|
||||
for i in range((len(batches) // self.num_shards))
|
||||
]
|
||||
np.random.shuffle(grouped_batches)
|
||||
batches = list(itertools.chain(*grouped_batches))
|
||||
else:
|
||||
np.random.shuffle(batches)
|
||||
|
||||
return batches
|
||||
|
||||
if self._supports_prefetch:
|
||||
batches = self.frozen_batches
|
||||
|
||||
if shuffle and not fix_batches_to_gpus:
|
||||
batches = shuffle_batches(list(batches), self.seed + epoch)
|
||||
|
||||
batches = list(
|
||||
ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
|
||||
)
|
||||
self.dataset.prefetch([i for s in batches for i in s])
|
||||
|
||||
if shuffle and fix_batches_to_gpus:
|
||||
batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
|
||||
else:
|
||||
if shuffle:
|
||||
batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
|
||||
else:
|
||||
batches = self.frozen_batches
|
||||
batches = list(
|
||||
ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
|
||||
)
|
||||
return batches
|
||||
|
||||
|
||||
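
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of driving EpochBatchIterator end to end; the toy dataset,
# the lambda collate_fn, and the fixed batch list are invented for the demo.
# Call the helper manually to see it run.
def _demo_epoch_batch_iterator():
    dataset = torch.utils.data.TensorDataset(torch.arange(8))
    epoch_itr = EpochBatchIterator(
        dataset=dataset,
        collate_fn=lambda items: torch.stack([x[0] for x in items]),
        batch_sampler=[[0, 1], [2, 3], [4, 5], [6, 7]],
    )
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    print(next(itr))  # one collated batch of two elements
    state = epoch_itr.state_dict()  # snapshot mid-epoch...
    epoch_itr.load_state_dict(state)  # ...and fast-forward past consumed batches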


class GroupedIterator(CountingIterator):
    """Wrapper around an iterable that returns groups (chunks) of items.

    Args:
        iterable (iterable): iterable to wrap
        chunk_size (int): size of each chunk
        skip_remainder_batch (bool, optional): if set, discard the last grouped batch in
            each training epoch, as the last grouped batch is usually smaller than
            local_batch_size * distributed_world_size * chunk_size (default: ``False``).
    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable, chunk_size, skip_remainder_batch=False):
        if skip_remainder_batch:
            total_num_itrs = int(math.floor(len(iterable) / float(chunk_size)))
            logger.info(
                f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}"
            )
        else:
            total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size)))
            logger.info(f"grouped total_num_itrs = {total_num_itrs}")

        itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch)
        super().__init__(
            itr,
            start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
            total=total_num_itrs,
        )
        self.chunk_size = chunk_size

        if skip_remainder_batch:
            self.take(total_num_itrs)
            # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that
            # training can move into the next epoch once the grouped iterator is exhausted.
            # Double-check this implementation in case unexpected behavior occurs.
            iterable.take(total_num_itrs * chunk_size)


def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False):
    chunk = []
    for x in itr:
        chunk.append(x)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if not skip_remainder_batch and len(chunk) > 0:
        yield chunk
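
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of GroupedIterator chunking: a plain list wrapped in a
# CountingIterator, grouped in chunks of 3. Call manually to inspect.
def _demo_grouped_iterator():
    base = CountingIterator(list(range(10)))
    for group in GroupedIterator(base, chunk_size=3):
        print(group)  # [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]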


class ShardedIterator(CountingIterator):
    """A sharded wrapper around an iterable, padded to length.

    Args:
        iterable (iterable): iterable to wrap
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterate over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards* (default: None).

    Attributes:
        n (int): number of elements consumed from this iterator
    """

    def __init__(
        self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None
    ):
        """
        Args:
            skip_remainder_batch: ignored
        """
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError("shard_id must be between 0 and num_shards")
        sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
        itr = map(
            operator.itemgetter(1),
            itertools.zip_longest(
                range(sharded_len),
                itertools.islice(iterable, shard_id, len(iterable), num_shards),
                fillvalue=fill_value,
            ),
        )
        super().__init__(
            itr,
            start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))),
            total=sharded_len,
        )
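
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of sharding: 10 elements split across 3 shards; the last
# shard is padded with fill_value so every shard has the same length.
def _demo_sharded_iterator():
    print(list(ShardedIterator(list(range(10)), num_shards=3, shard_id=0, fill_value=-1)))
    # -> [0, 3, 6, 9]
    print(list(ShardedIterator(list(range(10)), num_shards=3, shard_id=2, fill_value=-1)))
    # -> [2, 5, 8, -1]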


class BackgroundConsumer(Thread):
    def __init__(self, queue, source, max_len, cuda_device):
        Thread.__init__(self)

        self._queue = queue
        self._source = source
        self._max_len = max_len
        self.count = 0
        self.cuda_device = cuda_device

    def run(self):
        # set_device to avoid creation of GPU0 context when using pin_memory
        if self.cuda_device is not None:
            torch.cuda.set_device(self.cuda_device)

        try:
            for item in self._source:
                self._queue.put(item)

                # Stop if we reached the maximum length
                self.count += 1
                if self._max_len is not None and self.count >= self._max_len:
                    break

            # Signal the consumer we are done.
            self._queue.put(_sentinel)
        except Exception as e:
            self._queue.put(e)


class BufferedIterator(object):
    def __init__(self, size, iterable):
        self._queue = queue.Queue(size)
        self._iterable = iterable
        self._consumer = None

        self.start_time = time.time()
        self.warning_time = None

        self.total = len(iterable)

    def _create_consumer(self):
        self._consumer = BackgroundConsumer(
            self._queue,
            self._iterable,
            self.total,
            torch.cuda.current_device() if torch.cuda.is_available() else None,
        )
        self._consumer.daemon = True
        self._consumer.start()

    def __iter__(self):
        return self

    def __len__(self):
        return self.total

    def take(self, n):
        self.total = min(self.total, n)
        # Propagate this change to the underlying iterator
        if hasattr(self._iterable, "take"):
            self._iterable.take(n)
        return self

    def __next__(self):
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()

        # Notify the user if there is a data loading bottleneck
        if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
            if time.time() - self.start_time > 5 * 60:
                if (
                    self.warning_time is None
                    or time.time() - self.warning_time > 15 * 60
                ):
                    logger.debug(
                        "Data loading buffer is empty or nearly empty. This may "
                        "indicate a data loading bottleneck, and increasing the "
                        "number of workers (--num-workers) may help."
                    )
                    self.warning_time = time.time()

        # Get next example
        item = self._queue.get(True)
        if isinstance(item, Exception):
            raise item
        if item is _sentinel:
            raise StopIteration()
        return item
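
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of BufferedIterator: a background thread pre-fills a
# bounded queue (here 4 slots) while the consumer drains it in order.
def _demo_buffered_iterator():
    buffered = BufferedIterator(4, list(range(8)))
    print(list(buffered))  # [0, 1, 2, 3, 4, 5, 6, 7]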


class GroupedEpochBatchIterator(EpochBatchIterator):
    """Grouped version of EpochBatchIterator.

    It takes several samplers from different datasets. Each epoch, the
    dataset-wise samplers are shuffled individually with different random
    seeds, then the sub-samplers are combined into one big sampler with a
    deterministic permutation that mixes batches from the different datasets.
    It acts like EpochBatchIterator while making sure that
    1) each mini-batch comes from a single dataset, and
    2) different workers fetch batches in the same order, so at every step
       they all draw from the same dataset.
    mult_rate is used for the update_freq > 1 case, where we want update_freq
    consecutive mini-batches to come from the same source.
    """

    def __init__(
        self,
        dataset,
        collate_fn,
        batch_samplers,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
        mult_rate=1,
        buffer_size=0,
        skip_remainder_batch=False,
        reuse_dataloader=False,
        persistent_workers=False,
    ):
        super().__init__(
            dataset,
            collate_fn,
            batch_samplers,
            seed,
            num_shards,
            shard_id,
            num_workers,
            epoch,
            buffer_size,
            skip_remainder_batch=skip_remainder_batch,
            reuse_dataloader=reuse_dataloader,
            persistent_workers=persistent_workers,
        )
        # level 0: sub-samplers 1: batch_idx 2: batches
        self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
        self.step_size = mult_rate * num_shards

        self.lengths = [
            (len(x) // self.step_size) * self.step_size for x in self.frozen_batches
        ]

    def __len__(self):
        return sum(self.lengths)

    @property
    def first_batch(self):
        if len(self.frozen_batches) == 0:
            raise Exception(
                "The dataset is empty. This could indicate "
                "that all elements in the dataset have been skipped. "
                "Try increasing the max number of allowed tokens or using "
                "a larger dataset."
            )

        if self.dataset.supports_fetch_outside_dataloader:
            return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0][0]])
        else:
            return "DUMMY"

    def _get_iterator_for_epoch(
        self, epoch, shuffle, fix_batches_to_gpus=False, offset=0
    ):
        def shuffle_batches(batches, seed):
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

        def return_full_batches(batch_sets, seed, shuffle):
            if shuffle:
                batch_sets = [shuffle_batches(list(x), seed) for x in batch_sets]

            batch_sets = [
                batch_sets[i][: self.lengths[i]] for i in range(len(batch_sets))
            ]
            batches = list(itertools.chain.from_iterable(batch_sets))

            if shuffle:
                with data_utils.numpy_seed(seed):
                    idx = np.random.permutation(len(batches) // self.step_size)
                    if len(idx) * self.step_size != len(batches):
                        raise ValueError(
                            "ERROR: %d %d %d %d"
                            % (len(idx), self.step_size, len(batches), self.shard_id),
                            ":".join(["%d" % x for x in self.lengths]),
                        )
                    mini_shards = [
                        batches[i * self.step_size : (i + 1) * self.step_size]
                        for i in idx
                    ]
                    batches = list(itertools.chain.from_iterable(mini_shards))

            return batches

        if self._supports_prefetch:
            raise NotImplementedError("To be implemented")
        else:
            batches = return_full_batches(
                self.frozen_batches, self.seed + epoch, shuffle
            )
            batches = list(
                ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
            )

        if offset > 0 and offset >= len(batches):
            return None

        if self.num_workers > 0:
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        itr = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=batches[offset:],
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
        )
        if self.buffer_size > 0:
            itr = BufferedIterator(self.buffer_size, itr)

        return CountingIterator(itr, start=offset)
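
# --- Editor's illustration (not part of the original fairseq source). ---
# A standalone sketch of the mini-shard permutation used above: batches are
# grouped into blocks of step_size and the blocks (not individual batches)
# are permuted, so update_freq consecutive batches stay from one source.
def _demo_mini_shard_permutation(step_size=4):
    batches = list(range(12))  # stand-in for 12 collated batches
    idx = np.random.permutation(len(batches) // step_size)
    mini_shards = [batches[i * step_size : (i + 1) * step_size] for i in idx]
    print(list(itertools.chain.from_iterable(mini_shards)))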
477
modules/voice_conversion/fairseq/data/language_pair_dataset.py
Normal file
@@ -0,0 +1,477 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils


logger = logging.getLogger(__name__)


def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For example, if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        bsz, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = samples[i].get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch
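
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of collate() on two toy samples, assuming pad_idx=1 and
# eos_idx=2 as in fairseq's default Dictionary; the token ids are made up.
def _demo_collate():
    samples = [
        {"id": 0, "source": torch.LongTensor([4, 5, 2]), "target": torch.LongTensor([6, 2])},
        {"id": 1, "source": torch.LongTensor([7, 2]), "target": torch.LongTensor([8, 9, 2])},
    ]
    batch = collate(samples, pad_idx=1, eos_idx=2)
    print(batch["net_input"]["src_tokens"])  # left-padded to shape (2, 3)
    print(batch["net_input"]["prev_output_tokens"])  # targets shifted right by one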


class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to the end of the tgt sentence if it is missing, and remove
        # EOS from the end of the src sentence if it is present. This is useful
        # when we want to use existing datasets for opposite directions, i.e.,
        # use tgt_dataset as src_dataset and vice versa.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            # max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
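
# --- Editor's illustration (not part of the original fairseq source). ---
# A minimal sketch of building a LanguagePairDataset from in-memory tensors;
# fairseq's Dictionary defaults (pad=1, eos=2) are assumed and the token ids
# are made up. A plain Python list works as the wrapped "dataset".
def _demo_language_pair_dataset():
    from fairseq.data import Dictionary

    d = Dictionary()
    src = [torch.LongTensor([4, 2]), torch.LongTensor([5, 6, 2])]
    ds = LanguagePairDataset(src, src_sizes=[2, 3], src_dict=d)
    batch = ds.collater([ds[0], ds[1]])
    print(batch["net_input"]["src_tokens"])  # shape (2, 3), left-padded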
16
modules/voice_conversion/fairseq/data/legacy/__init__.py
Normal file
@@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .block_pair_dataset import BlockPairDataset
from .masked_lm_dataset import MaskedLMDataset
from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary


__all__ = [
    "BertDictionary",
    "BlockPairDataset",
    "MaskedLMDataset",
    "MaskedLMDictionary",
]
311
modules/voice_conversion/fairseq/data/legacy/block_pair_dataset.py
Normal file
@@ -0,0 +1,311 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
from fairseq.data import FairseqDataset


class BlockPairDataset(FairseqDataset):
    """Break a Dataset of tokens into sentence pair blocks for next-sentence
    prediction as well as masked language modeling.

    The high-level logic is:
    1. break the input tensor into blocks
    2. pair the blocks, 50% with the next sentence and 50% with a random sentence
    3. return the paired blocks as well as the related segment labels

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes: array of sentence lengths
        dictionary: dictionary for the task
        block_size: maximum block size
        break_mode: mode for breaking the corpus into block pairs. Currently we
            support 2 modes
            doc: respect document boundaries; each part of the pair should belong to one document
            none: don't respect any boundary and cut tokens evenly
        short_seq_prob: probability for generating shorter block pairs
        doc_break_size: size of the empty line separating documents. Typically 1 if
            the sentences have eos, 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        dictionary,
        sizes,
        block_size,
        break_mode="doc",
        short_seq_prob=0.1,
        doc_break_size=1,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = dictionary.pad()
        self.eos = dictionary.eos()
        self.cls = dictionary.cls()
        self.mask = dictionary.mask()
        self.sep = dictionary.sep()
        self.break_mode = break_mode
        self.dictionary = dictionary
        self.short_seq_prob = short_seq_prob
        self.block_indices = []

        assert len(dataset) == len(sizes)

        if break_mode == "doc":
            cur_doc = []
            for sent_id, sz in enumerate(sizes):
                assert doc_break_size == 0 or sz != 0, (
                    "when doc_break_size is non-zero, we expect documents to be "
                    "separated by a blank line with a single eos."
                )
                # empty line as document separator
                if sz == doc_break_size:
                    if len(cur_doc) == 0:
                        continue
                    self.block_indices.append(cur_doc)
                    cur_doc = []
                else:
                    cur_doc.append(sent_id)
            max_num_tokens = block_size - 3  # Account for [CLS], [SEP], [SEP]
            self.sent_pairs = []
            self.sizes = []
            for doc_id, doc in enumerate(self.block_indices):
                self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
        elif break_mode is None or break_mode == "none":
            # each block should have half of the block size since we are constructing block pair
            sent_length = (block_size - 3) // 2
            total_len = sum(dataset.sizes)
            length = math.ceil(total_len / sent_length)

            def block_at(i):
                start = i * sent_length
                end = min(start + sent_length, total_len)
                return (start, end)

            sent_indices = np.array([block_at(i) for i in range(length)])
            sent_sizes = np.array([e - s for s, e in sent_indices])
            dataset_index = self._sent_to_dataset_index(sent_sizes)

            # pair sentences
            self._pair_sentences(dataset_index)
        else:
            raise ValueError("Invalid break_mode: " + break_mode)

    def _pair_sentences(self, dataset_index):
        """
        Given a list of evenly cut blocks/sentences, pair these sentences with 50%
        consecutive sentences and 50% random sentences.
        This is used for the "none" break mode.
        """
        # pair sentences
        for sent_id, sent in enumerate(dataset_index):
            next_sent_label = (
                1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
            )
            if next_sent_label:
                next_sent = dataset_index[sent_id + 1]
            else:
                next_sent = dataset_index[
                    self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
                ]
            self.sent_pairs.append((sent, next_sent, next_sent_label))

            # The current blocks don't include the special tokens but the
            # sizes already account for this
            self.sizes.append(3 + sent[3] + next_sent[3])

    def _sent_to_dataset_index(self, sent_sizes):
        """
        Build index mapping block indices to the underlying dataset indices
        """
        dataset_index = []
        ds_idx, ds_remaining = -1, 0
        for to_consume in sent_sizes:
            sent_size = to_consume
            if ds_remaining == 0:
                ds_idx += 1
                ds_remaining = sent_sizes[ds_idx]
            start_ds_idx = ds_idx
            start_offset = sent_sizes[ds_idx] - ds_remaining
            while to_consume > ds_remaining:
                to_consume -= ds_remaining
                ds_idx += 1
                ds_remaining = sent_sizes[ds_idx]
            ds_remaining -= to_consume
            dataset_index.append(
                (
                    start_ds_idx,  # starting index in dataset
                    start_offset,  # starting offset within starting index
                    ds_idx,  # ending index in dataset
                    sent_size,  # sentence length
                )
            )
        assert ds_remaining == 0
        assert ds_idx == len(self.dataset) - 1
        return dataset_index

    def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
        """
        Go through a single document and generate sentence pairs from it
        """
        current_chunk = []
        current_length = 0
        curr = 0
        # To provide more randomness, we decrease target seq length for parts of
        # samples (10% by default). Note that max_num_tokens is the hard threshold
        # for batching and will never be changed.
        target_seq_length = max_num_tokens
        if np.random.random() < self.short_seq_prob:
            target_seq_length = np.random.randint(2, max_num_tokens)
        # loop through all sentences in document
        while curr < len(doc):
            sent_id = doc[curr]
            current_chunk.append(sent_id)
            current_length = sum(sizes[current_chunk])
            # split chunk and generate pair when we exceed target_seq_length or
            # finish the loop
            if curr == len(doc) - 1 or current_length >= target_seq_length:
                # split the chunk into 2 parts
                a_end = 1
                if len(current_chunk) > 2:
                    a_end = np.random.randint(1, len(current_chunk) - 1)
                sent_a = current_chunk[:a_end]
                len_a = sum(sizes[sent_a])
                # generate next sentence label, note that if there is only 1 sentence
                # in current chunk, label is always 0
                next_sent_label = (
                    1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
                )
                if not next_sent_label:
                    # if next sentence label is 0, sample sent_b from a random doc
                    target_b_length = target_seq_length - len_a
                    rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
                    random_doc = self.block_indices[rand_doc_id]
                    random_start = np.random.randint(0, len(random_doc))
                    sent_b = []
                    len_b = 0
                    for j in range(random_start, len(random_doc)):
                        sent_b.append(random_doc[j])
                        len_b = sum(sizes[sent_b])
                        if len_b >= target_b_length:
                            break
                    # put the second part of the chunk back, since it wasn't used
                    num_unused_segments = len(current_chunk) - a_end
                    curr -= num_unused_segments
                else:
                    # if next sentence label is 1, use the second part of the chunk as sent_b
                    sent_b = current_chunk[a_end:]
                    len_b = sum(sizes[sent_b])
                # currently sent_a and sent_b may be longer than max_num_tokens,
                # truncate them and return block idx and offsets for them
                sent_a, sent_b = self._truncate_sentences(
                    sent_a, sent_b, max_num_tokens
                )
                self.sent_pairs.append((sent_a, sent_b, next_sent_label))
                self.sizes.append(3 + sent_a[3] + sent_b[3])
                current_chunk = []
            curr += 1

    def _skip_sampling(self, total, skip_ids):
        """
        Generate a random integer which is not in skip_ids. Sample range is [0, total)
        TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later
        """
        rand_id = np.random.randint(total - len(skip_ids))
        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
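
    # --- Editor's note (not in the original fairseq source): the sampling
    # trick above draws uniformly from [0, total) excluding skip_ids by
    # sampling from a range shrunk by len(skip_ids) and shifting any draw at
    # or past min(skip_ids). E.g. with total=10 and skip_ids=[3, 4], draws
    # 0-2 stay put and draws 3-7 map to 5-9, so 3 and 4 are never returned.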

    def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
        """
        Truncate a pair of sentences to limit their total length to max_num_tokens.
        Logic:
        1. Truncate the longer sentence
        2. Tokens to be truncated could be at the beginning or the end of the sentence
        Returns:
            Truncated sentences represented by dataset idx
        """
        len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
        front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0

        while True:
            total_length = (
                len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
            )
            if total_length <= max_num_tokens:
                break

            if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
                if np.random.rand() < 0.5:
                    front_cut_a += 1
                else:
                    end_cut_a += 1
            else:
                if np.random.rand() < 0.5:
                    front_cut_b += 1
                else:
                    end_cut_b += 1

        # calculate ds indices as well as offsets and return
        truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
        truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
        return truncated_sent_a, truncated_sent_b

    def _cut_sentence(self, sent, front_cut, end_cut):
        """
        Cut a sentence based on the number of tokens to be cut from the beginning and end.
        Represent the sentence as dataset idx and return
        """
        start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
        target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
        while front_cut > 0:
            if self.dataset.sizes[start_ds_idx] > front_cut:
                offset += front_cut
                break
            else:
                front_cut -= self.dataset.sizes[start_ds_idx]
                start_ds_idx += 1
        while end_cut > 0:
            if self.dataset.sizes[end_ds_idx] > end_cut:
                break
            else:
                end_cut -= self.dataset.sizes[end_ds_idx]
                end_ds_idx -= 1
        return start_ds_idx, offset, end_ds_idx, target_len

    def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
        """
        Fetch a block of tokens based on its dataset idx
        """
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        s, e = offset, offset + length
        return buffer[s:e]

    def __getitem__(self, index):
        block1, block2, next_sent_label = self.sent_pairs[index]
        block1 = self._fetch_block(*block1)
        block2 = self._fetch_block(*block2)
        return block1, block2, next_sent_label

    def __len__(self):
        return len(self.sizes)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        prefetch_idx = set()
        for index in indices:
            for block1, block2, _ in [self.sent_pairs[index]]:
                for ds_idx in range(block1[0], block1[2] + 1):
                    prefetch_idx.add(ds_idx)
                for ds_idx in range(block2[0], block2[2] + 1):
                    prefetch_idx.add(ds_idx)
        self.dataset.prefetch(prefetch_idx)
303
modules/voice_conversion/fairseq/data/legacy/masked_lm_dataset.py
Normal file
@@ -0,0 +1,303 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, List, Tuple

import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset, data_utils
from fairseq.data.concat_dataset import ConcatDataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset


class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modeling. The dataset
    wraps around TokenBlockDataset or BlockPairDataset and creates a batch
    where the input blocks are masked according to the specified masking
    probability. Additionally, the batch can also contain sentence-level
    targets if this is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only BlockPairDataset
            and TokenBlockDataset are supported.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is the
            token associated with the sentence embedding (Eg: CLS for BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset
            generates a pair of blocks along with a sentence_target or not.
            Setting it to True assumes that the underlying dataset generates a
            label for the pair of sentences which is surfaced as
            sentence_target. The default value assumes a single block with no
            sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: specifies what percentage of the blocks should be masked.
        masking_prob: specifies the probability of a given token being
            replaced with the "MASK" token.
        random_token_prob: specifies the probability of a given token being
            replaced by a random token from the vocabulary.
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        sizes: np.ndarray,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        classif_token_idx: int,
        sep_token_idx: int,
        seed: int = 1,
        shuffle: bool = True,
        has_pairs: bool = True,
        segment_id: int = 0,
        masking_ratio: float = 0.15,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        # Make sure the input datasets are the ones supported
        assert (
            isinstance(dataset, TokenBlockDataset)
            or isinstance(dataset, BlockPairDataset)
            or isinstance(dataset, ConcatDataset)
        ), (
            "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
            "ConcatDataset"
        )

        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # If we have only one block then sizes needs to be updated to include
        # the classification token
        if not has_pairs:
            self.sizes = self.sizes + 1

    def __getitem__(self, index: int):
        # if has_pairs, then expect 2 blocks and a sentence target
        if self.has_pairs:
            (block_one, block_two, sentence_target) = self.dataset[index]
        else:
            block_one = self.dataset[index]

        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two if self.has_pairs else None,
            "sentence_target": sentence_target if self.has_pairs else None,
        }

    def __len__(self):
        return len(self.dataset)

    def _mask_block(
        self,
        sentence: np.ndarray,
        mask_idx: int,
        pad_idx: int,
        dictionary_token_range: Tuple,
    ):
        """
        Mask tokens for Masked Language Model training.
        Samples masking_ratio tokens that will be predicted by the LM.

        Note: This function may not be efficient enough since we have multiple
        conversions between np and torch; we could replace them with torch
        operators later.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: range of indices in dictionary which can
                be used for random word replacement
                (e.g. without special characters)
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        sent_length = len(sentence)
        mask_num = math.ceil(sent_length * self.masking_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)
        target = np.copy(sentence)

        for i in range(sent_length):
            if i in mask:
                rand = np.random.random()

                # replace with mask if probability is less than masking_prob
                # (Eg: 0.8)
                if rand < self.masking_prob:
                    masked_sent[i] = mask_idx

                # replace with random token if probability is less than
                # masking_prob + random_token_prob (Eg: 0.9)
                elif rand < (self.masking_prob + self.random_token_prob):
                    # sample random token from dictionary
                    masked_sent[i] = np.random.randint(
                        dictionary_token_range[0], dictionary_token_range[1]
                    )
            else:
                target[i] = pad_idx

        return masked_sent, target
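
    # --- Editor's note (not in the original fairseq source): with the default
    # masking_ratio=0.15, masking_prob=0.8 and random_token_prob=0.1, roughly
    # 15% of positions are selected; a selected position becomes the mask
    # token with p=0.8, a random vocabulary token with p=0.1, and keeps its
    # original token with p=0.1. Unselected positions have their target set
    # to pad_idx so the loss ignores them (the standard BERT 80/10/10 scheme).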

    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
        1. Mask the input blocks. In case has_pair is True then we have 2
           blocks to mask.
        2. Prepend the first masked block tensor with the special token
           used as sentence embedding. Eg: CLS in BERT. This happens
           irrespective of the value of has_pair.
        3. If has_pair is True, then append the first masked block with the
           special separator token (eg: SEP for BERT) and compute segment
           label accordingly. In this case, also append the second masked
           block with this special separator token and compute its segment
           label.
        4. For the targets tensor, prepend and append with padding index
           accordingly.
        5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add determinism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range
                    )
                    tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples])
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }

    def collater(self, samples: List[Dict]):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index: int):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]
|
||||
|
||||
def size(self, index: int):
|
||||
"""
|
||||
Return an example's size as a float or tuple. This value is used when
|
||||
filtering a dataset with max-positions.
|
||||
"""
|
||||
return self.sizes[index]
|
||||
|
||||
def ordered_indices(self):
|
||||
"""
|
||||
Return an ordered list of indices. Batches will be constructed based
|
||||
on this order.
|
||||
"""
|
||||
if self.shuffle:
|
||||
return np.random.permutation(len(self))
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
order.append(self.sizes)
|
||||
return np.lexsort(order)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return getattr(self.dataset, "supports_prefetch", False)
|
||||
|
||||
def prefetch(self, indices):
|
||||
self.dataset.prefetch(indices)
|
||||
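For reference, a minimal standalone sketch (not part of the upstream diff) of the BERT-style 80/10/10 rule that `_mask_block` above implements; the parameter names mirror the class attributes, and all values below are illustrative.

import math
import numpy as np

def mask_block_demo(sentence, mask_idx, pad_idx, token_range,
                    masking_ratio=0.15, masking_prob=0.8, random_token_prob=0.1):
    # Pick ceil(len * masking_ratio) positions; of those, ~80% become <mask>,
    # ~10% become a random vocabulary token, ~10% stay unchanged. Positions
    # that are not picked are padded out of the target (ignored by the loss).
    masked = sentence.copy()
    target = sentence.copy()
    num_to_mask = math.ceil(len(sentence) * masking_ratio)
    chosen = set(np.random.choice(len(sentence), num_to_mask, replace=False))
    for i in range(len(sentence)):
        if i in chosen:
            r = np.random.random()
            if r < masking_prob:
                masked[i] = mask_idx
            elif r < masking_prob + random_token_prob:
                masked[i] = np.random.randint(*token_range)
        else:
            target[i] = pad_idx
    return masked, target

sent = np.arange(10, 30)  # toy token ids
masked, tgt = mask_block_demo(sent, mask_idx=3, pad_idx=1, token_range=(4, 100))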
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data import Dictionary


class MaskedLMDictionary(Dictionary):
    """
    Dictionary for Masked Language Modelling tasks. This extends Dictionary by
    adding the mask symbol.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk)
        self.mask_word = mask
        self.mask_index = self.add_symbol(mask)
        self.nspecial = len(self.symbols)

    def mask(self):
        """Helper to get index of mask symbol"""
        return self.mask_index


class BertDictionary(MaskedLMDictionary):
    """
    Dictionary for BERT task. This extends MaskedLMDictionary by adding support
    for cls and sep symbols.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        mask="<mask>",
        cls="<cls>",
        sep="<sep>",
    ):
        super().__init__(pad=pad, eos=eos, unk=unk, mask=mask)
        self.cls_word = cls
        self.sep_word = sep
        self.cls_index = self.add_symbol(cls)
        self.sep_index = self.add_symbol(sep)
        self.nspecial = len(self.symbols)

    def cls(self):
        """Helper to get index of cls symbol"""
        return self.cls_index

    def sep(self):
        """Helper to get index of sep symbol"""
        return self.sep_index
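A hedged usage sketch (not part of the diff): the helper methods just return the indices registered by `add_symbol` above; the import path mirrors upstream fairseq and is an assumption here.

from fairseq.data.legacy.masked_lm_dictionary import BertDictionary  # assumed path

d = BertDictionary()
assert d.mask() == d.index("<mask>")
assert d.cls() == d.index("<cls>")
assert d.sep() == d.index("<sep>")
# nspecial is bumped after each add_symbol so that random-token replacement
# in masked LM training can exclude every special symbol.
print(d.nspecial, len(d))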
32
modules/voice_conversion/fairseq/data/list_dataset.py
Normal file
@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ListDataset(BaseWrapperDataset):
    def __init__(self, dataset, sizes=None):
        super().__init__(dataset)
        self._sizes = sizes

    def __iter__(self):
        for x in self.dataset:
            yield x

    def collater(self, samples):
        return samples

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    def set_epoch(self, epoch):
        pass
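Illustrative only: `ListDataset` passes a plain Python list through the FairseqDataset interface, with `sizes` driving token-based batching.

import numpy as np

data = [np.array([5, 6, 7]), np.array([8, 9])]
ds = ListDataset(data, sizes=np.array([3, 2]))
assert ds.num_tokens(0) == 3
batch = ds.collater([ds[0], ds[1]])  # identity collation: returns the list as-is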
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from typing import Dict

from fairseq.data.monolingual_dataset import MonolingualDataset

from . import FairseqDataset


class LMContextWindowDataset(FairseqDataset):
    """
    Wraps a MonolingualDataset and provides more context for evaluation.

    Each item in the new dataset will have a maximum size of
    ``tokens_per_sample + context_window``.

    Args:
        dataset: dataset to wrap
        tokens_per_sample (int): the max number of tokens in each dataset item
        context_window (int): the number of accumulated tokens to add to each
            dataset item
        pad_idx (int): padding symbol
    """

    def __init__(
        self,
        dataset: MonolingualDataset,
        tokens_per_sample: int,
        context_window: int,
        pad_idx: int,
    ):
        assert context_window > 0
        self.dataset = dataset
        self.tokens_per_sample = tokens_per_sample
        self.context_window = context_window
        self.pad_idx = pad_idx
        self.prev_tokens = np.empty([0])

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples) -> Dict:
        sample = self.dataset.collater(samples)

        pad = self.pad_idx
        max_sample_len = self.tokens_per_sample + self.context_window

        bsz, tsz = sample["net_input"]["src_tokens"].shape
        start_idxs = [0] * bsz
        toks = sample["net_input"]["src_tokens"]
        lengths = sample["net_input"]["src_lengths"]
        tgt = sample["target"]
        new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
        new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
        sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
        for i in range(bsz):
            sample_len = sample_lens[i]
            extra = len(self.prev_tokens) + sample_len - max_sample_len
            if extra > 0:
                self.prev_tokens = self.prev_tokens[extra:]
            pads = np.full(self.context_window - len(self.prev_tokens), pad)
            new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
            new_tgt[
                i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
            ] = tgt[i]
            start_idxs[i] = len(self.prev_tokens)
            lengths[i] += len(self.prev_tokens)
            self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :]
        sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks)
        sample["target"] = torch.from_numpy(new_tgt)
        sample["start_indices"] = start_idxs
        return sample

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        # NOTE we don't shuffle the data to retain access to the previous dataset elements
        return np.arange(len(self.dataset))

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
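A toy restatement (not fairseq code) of the per-row bookkeeping in `collater()` above: tokens carried over from the previous batch are prepended, the oldest context is trimmed when the row would exceed `tokens_per_sample + context_window`, and loss evaluation starts at `start`.

import numpy as np

pad, context_window, max_len = 1, 3, 8
prev_tokens = np.array([11, 12, 13])  # carried over from the previous batch
row = np.array([21, 22, 23, 24])      # current sample, no padding
extra = len(prev_tokens) + len(row) - max_len
if extra > 0:
    prev_tokens = prev_tokens[extra:]  # drop the oldest context first
pads = np.full(context_window - len(prev_tokens), pad)
new_row = np.concatenate([prev_tokens, row, pads])
start = len(prev_tokens)  # target positions before this index stay padded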
21
modules/voice_conversion/fairseq/data/lru_cache_dataset.py
Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

from . import BaseWrapperDataset


class LRUCacheDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)

    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        return self.dataset[index]

    @lru_cache(maxsize=8)
    def collater(self, samples):
        return self.dataset.collater(samples)
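A small sketch of why this wrapper exists: `lru_cache(maxsize=8)` memoizes `__getitem__` per index, so the paired source/target views built by `MaskTokensDataset.apply_mask` further below read each underlying item only once; `ListDataset` is the wrapper defined earlier in this diff.

import numpy as np

base = ListDataset([np.array([2, 3, 4])], sizes=np.array([3]))
cached = LRUCacheDataset(base)
assert cached[0] is cached[0]  # the second access is served from the cache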
220
modules/voice_conversion/fairseq/data/mask_tokens_dataset.py
Normal file
@@ -0,0 +1,220 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

import numpy as np
import torch
from fairseq.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset


class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.

    Args:
        dataset: Dataset to wrap.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        return_masked_tokens: controls whether to return the non-masked tokens
            (the default) or to return a tensor with the original masked token
            IDs (and *pad_idx* elsewhere). The latter is useful as targets for
            masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        bpe: BPE to use for whole-word masking.
        mask_multiple_length: repeat each mask index multiple times. Default
            value is 1.
        mask_stdev: standard deviation of masks distribution in case of
            multiple masking. Default value is 0.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training."""
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
        mask_multiple_length: int = 1,
        mask_stdev: float = 0.0,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0
        assert mask_multiple_length >= 1
        assert mask_stdev >= 0.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words
        self.mask_multiple_length = mask_multiple_length
        self.mask_stdev = mask_stdev

        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            weights[: self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.seed, self.epoch, index)

    @lru_cache(maxsize=8)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            if self.mask_whole_words is not None:
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz / float(self.mask_multiple_length)
                + np.random.rand()
            )

            # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                lengths = np.random.normal(
                    self.mask_multiple_length, self.mask_stdev, size=num_mask
                )
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        mask_idc[j] + offset
                        for j in range(len(mask_idc))
                        for offset in range(lengths[j])
                    ],
                    dtype=np.int64,
                )
            else:
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)]
                )
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except:  # something wrong
                print(
                    "Assigning mask indexes {} to mask {} failed!".format(
                        mask_idc, mask
                    )
                )
                raise

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
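Hedged usage sketch: `token_dataset` is a hypothetical indexable dataset of LongTensors and `vocab` a fairseq Dictionary; the call itself matches `apply_mask` as defined above.

mask_idx = vocab.add_symbol("<mask>")
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
    token_dataset,           # hypothetical dataset of LongTensor token ids
    vocab=vocab,
    pad_idx=vocab.pad(),
    mask_idx=mask_idx,
    seed=42,
    mask_prob=0.15,
)
# src_dataset[i] carries mask_idx in roughly 15% of positions, while
# tgt_dataset[i] holds the original ids there and pad_idx everywhere else.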
253
modules/voice_conversion/fairseq/data/monolingual_dataset.py
Normal file
@@ -0,0 +1,253 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    if len(samples) == 0:
        return {}

    def merge(key, is_list=False):
        if is_list:
            res = []
            for i in range(len(samples[0][key])):
                res.append(
                    data_utils.collate_tokens(
                        [s[key][i] for s in samples],
                        pad_idx,
                        eos_idx,
                        left_pad=False,
                        pad_to_length=fixed_pad_length,
                        pad_to_bsz=pad_to_bsz,
                    )
                )
            return res
        else:
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                pad_idx,
                eos_idx,
                left_pad=False,
                pad_to_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
            )

    src_tokens = merge("source")
    if samples[0]["target"] is not None:
        is_target_list = isinstance(samples[0]["target"], list)
        target = merge("target", is_target_list)
    else:
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }


class MonolingualDataset(FairseqDataset):
    """
    A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        shuffle (bool, optional): shuffle the elements before batching
            (default: True).
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab=None,
        add_eos_for_other_targets=False,
        shuffle=False,
        targets=None,
        add_bos_token=False,
        fixed_pad_length=None,
        pad_to_bsz=None,
        src_lang_idx=None,
        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab or src_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
        self.fixed_pad_length = fixed_pad_length
        self.pad_to_bsz = pad_to_bsz
        self.src_lang_idx = src_lang_idx
        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be none or one of 'self', 'future', 'past'"
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets

    def __getitem__(self, index):
        if self.targets is not None:
            # *future_target* is the original sentence
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Left-to-right language models should condition on *source* and
            # predict *future_target*.
            # Right-to-left language models should condition on *source* and
            # predict *past_target*.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # append eos at the end of source
                source = torch.cat([source, source.new([self.vocab.eos()])])

                if "future" in self.targets:
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # first token is before the start of sentence which is only used in "none" break mode when
                    # add_eos_for_other_targets is False
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            if len(target) == 1:
                target = target[0]
        else:
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        return self.sizes[indices]

    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):

            def _filter(target):
                mask = target.ge(len(self.tgt_vocab))
                if mask.any():
                    target[mask] = self.tgt_vocab.unk()
                return target

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the right.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the right.
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.fixed_pad_length,
            self.pad_to_bsz,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
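A sketch of the batch layout `collate()` above produces; `vocab` (a fairseq Dictionary) and `lines` (a list of LongTensors) are assumed to exist.

samples = [{"id": i, "source": t, "target": None} for i, t in enumerate(lines)]
batch = collate(samples, vocab.pad(), vocab.eos())
batch["net_input"]["src_tokens"]   # (bsz, max_len) LongTensor, right-padded
batch["net_input"]["src_lengths"]  # true (unpadded) length of each sample
batch["target"]                    # aliases src_tokens here since targets are None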
285
modules/voice_conversion/fairseq/data/multi_corpus_dataset.py
Normal file
@@ -0,0 +1,285 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import time
from collections import OrderedDict
from typing import Dict, List, Optional

import numpy as np

from fairseq.data import data_utils

from . import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together.
    Unless batch_sample=True, requires each instance
    to be the same dataset, as the collate method needs to work on batches with
    samples from each dataset.

    Allows specifying a distribution over the datasets to use. Note that unlike
    MultiCorpusSampledDataset, this distribution allows sampling for each item,
    rather than on a batch level. Note that datasets with a sampling probability
    of 0 will be skipped.

    Each time ordered_indices() is called, a new sample is generated with
    the specified distribution.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance from
            the corresponding dataset
        seed: random seed for sampling the datasets
        sort_indices: if true, will sort the ordered indices by size
        batch_sample: if true, will ensure each batch is from a single dataset
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        distribution: List[float],
        seed: int,
        sort_indices: bool = False,
        batch_sample: bool = False,
        distributed_rank: Optional[int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        assert sum(distribution) == 1
        self.datasets = datasets
        self.distribution = distribution
        self.seed = seed
        self.sort_indices = sort_indices
        self.batch_sample = batch_sample
        self.distributed_rank = distributed_rank

        # Avoid repeated conversions to list later
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        first_dataset = self.dataset_list[0]

        self.num_instances_per_dataset = []
        self.dataset_offsets = []
        for i, dataset in enumerate(self.dataset_list):
            assert isinstance(dataset, FairseqDataset)
            assert type(dataset) is type(first_dataset)
            self.num_instances_per_dataset.append(
                0 if self.distribution[i] == 0 else len(dataset)
            )
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += self.num_instances_per_dataset[i]

    def ordered_indices(self):
        start = time.time()
        with data_utils.numpy_seed(self.seed, self.epoch):
            logger.info(
                f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
            )
            sampled_indices = []
            num_selected_instances = 0

            # For each dataset i, sample self.distribution[i] * self.total_num_instances
            for i, key in enumerate(self.datasets):
                if self.distribution[i] == 0:
                    # skip dataset if sampling probability is 0
                    continue

                if i < len(self.datasets) - 1:
                    num_instances = int(self.distribution[i] * self.total_num_instances)
                    high = self.dataset_offsets[i + 1]
                else:
                    num_instances = self.total_num_instances - num_selected_instances
                    high = self.total_num_instances

                logger.info(f"sampling {num_instances} from {key} dataset")
                num_selected_instances += num_instances

                # First, add k copies of the dataset where k = num_instances // len(dataset).
                # This ensures an equal distribution of the data points as much as possible.
                # For the remaining entries randomly sample them
                dataset_size = len(self.datasets[key])
                num_copies = num_instances // dataset_size
                dataset_indices = (
                    np.random.permutation(high - self.dataset_offsets[i])
                    + self.dataset_offsets[i]
                )[: num_instances - num_copies * dataset_size]
                if num_copies > 0:
                    sampled_indices += list(
                        np.concatenate(
                            (
                                np.repeat(
                                    np.arange(self.dataset_offsets[i], high), num_copies
                                ),
                                dataset_indices,
                            )
                        )
                    )
                else:
                    sampled_indices += list(dataset_indices)

            assert (
                len(sampled_indices) == self.total_num_instances
            ), f"{len(sampled_indices)} vs {self.total_num_instances}"

            np.random.shuffle(sampled_indices)
            if self.sort_indices:
                sampled_indices.sort(key=lambda i: self.num_tokens(i))

            logger.info(
                "multi_corpus_dataset ordered_indices took {}s".format(
                    time.time() - start
                )
            )
            return np.array(sampled_indices, dtype=np.int64)

    def _map_index(self, index: int):
        """
        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.
        """
        counter = 0
        for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
            if index < counter + num_instances:
                return index - counter, key
            counter += num_instances
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    async def getitem(self, index):
        new_index, key = self._map_index(index)
        try:
            if hasattr(self.datasets[key], "getitem"):
                item = await self.datasets[key].getitem(new_index)
            else:
                item = self.datasets[key][new_index]
            item["full_id"] = index
            return item
        except Exception as e:
            e.args = (f"Error from {key} dataset", *e.args)
            raise

    def __getitem__(self, index):
        return asyncio.run(self.getitem(index))

    async def getitems(self, indices):
        # initialize a bunch of everstore read operations
        # wait in the end to reduce overhead
        # very helpful if io is latency bounded

        max_concurrency = 32
        sem = asyncio.Semaphore(max_concurrency)

        async def controlled_getitem(index):
            async with sem:
                return await self.getitem(index)

        coroutines = []
        for index in indices:
            coroutines.append(controlled_getitem(index))
        results = await asyncio.gather(*coroutines)
        return results

    def __getitems__(self, indices):
        return asyncio.run(self.getitems(indices))

    def collater(self, samples):
        """
        If we are doing batch sampling, then pick the right collater to use.

        Otherwise we assume all collaters are the same.
        """
        if len(samples) == 0:
            return None
        if "full_id" in samples[0]:
            _, key = self._map_index(samples[0]["full_id"])
            try:
                batch = self.datasets[key].collater(samples)
            except Exception:
                print(f"Collating failed for key {key}", flush=True)
                raise
            return batch
        else:
            # Subclasses may override __getitem__ to not specify full_id
            return list(self.datasets.values())[0].collater(samples)

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        if not self.batch_sample:
            return super().batch_by_size(
                indices, max_tokens, max_sentences, required_batch_size_multiple
            )

        dataset_indices = {key: [] for key in self.datasets}
        for i in indices:
            _, key = self._map_index(i)
            dataset_indices[key].append(i)

        batches = []
        for key in dataset_indices:
            cur_batches = super().batch_by_size(
                np.array(dataset_indices[key], dtype=np.int64),
                max_tokens,
                max_sentences,
                required_batch_size_multiple,
            )
            logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
            batches += cur_batches

        # If this dataset is used in a distributed training setup,
        # then shuffle such that the order is seeded by the distributed rank
        # as well
        if self.distributed_rank is not None:
            with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank):
                np.random.shuffle(batches)
        return batches
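A hedged construction sketch; `ds_clean` and `ds_noisy` are hypothetical FairseqDataset instances of the same type (the constructor asserts this), and the ratios must sum to exactly 1.

from collections import OrderedDict

corpora = OrderedDict([("clean", ds_clean), ("noisy", ds_noisy)])
mcd = MultiCorpusDataset(
    corpora,
    distribution=[0.75, 0.25],  # exact binary fractions keep sum(...) == 1
    seed=0,
    sort_indices=True,
    batch_sample=False,  # True would keep every batch within one corpus
)
mcd.set_epoch(1)                 # ordered_indices() reseeds per (seed, epoch)
indices = mcd.ordered_indices()  # per-item sample drawn from the distribution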
@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Callable, Dict, List

import numpy as np

from . import FairseqDataset


def uniform_sampler(x):
    # Sample from uniform distribution
    return np.random.choice(x, 1).item()


class MultiCorpusSampledDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset according to a specified
    probability distribution and then getting instances from that dataset.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_func: A function for sampling over list of dataset keys.
            The default strategy is to sample uniformly.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Callable[[List], int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        if sampling_func is None:
            sampling_func = uniform_sampler
        self.sampling_func = sampling_func

        self.total_num_instances = 0
        for _, dataset in datasets.items():
            assert isinstance(dataset, FairseqDataset)
            self.total_num_instances += len(dataset)

        self._ordered_indices = None

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Ordered indices for batching. Here we call the underlying
        dataset's ordered_indices() so that we get the same random ordering
        as we would have from using the underlying dataset directly.
        """
        if self._ordered_indices is None:
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: int, index: int):
        """
        Different underlying datasets have different lengths. In order to ensure
        we are not accessing an index outside the range of the current dataset
        size, we wrap around. This function should be called after we have
        created an ordering for this and all underlying datasets.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        mapped_index = index % len(self.datasets[key])
        return self._ordered_indices[key][mapped_index]

    def __getitem__(self, index: int):
        """
        Get the item associated with index from each underlying dataset.
        Since index is in the range of [0, TotalNumInstances], we need to
        map the index to the dataset before retrieving the item.
        """
        return OrderedDict(
            [
                (key, dataset[self._map_index_to_dataset(key, index)])
                for key, dataset in self.datasets.items()
            ]
        )

    def collater(self, samples: List[Dict]):
        """
        Generate a mini-batch for this dataset.
        To convert this into a regular mini-batch we use the following
        logic:
            1. Select a dataset using the specified probability distribution.
            2. Call the collater function of the selected dataset.
        """
        if len(samples) == 0:
            return None

        selected_key = self.sampling_func(list(self.datasets.keys()))
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)

    def num_tokens(self, index: int):
        """
        Return an example's length (number of tokens), used for batching. Here
        we return the max across all examples at index across all underlying
        datasets.
        """
        return max(
            dataset.num_tokens(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. Here we return the max
        across all underlying datasets. This value is used when filtering a
        dataset with max-positions.
        """
        return max(
            dataset.size(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch(
                [self._map_index_to_dataset(key, index) for index in indices]
            )

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )
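Sketch of plugging in a non-uniform sampler; the dataset names are illustrative. Note that `ordered_indices()` must run before indexing, since `_map_index_to_dataset` asserts the per-corpus orderings exist.

from collections import OrderedDict
import numpy as np

def weighted_sampler(keys, p=(0.75, 0.25)):
    # Bias batch selection toward the first corpus.
    return np.random.choice(keys, 1, p=p).item()

sampled = MultiCorpusSampledDataset(
    OrderedDict([("a", dataset_a), ("b", dataset_b)]),
    sampling_func=weighted_sampler,
)
sampled.ordered_indices()  # builds the per-corpus orderings
item = sampled[0]          # OrderedDict with one entry per corpus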
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
File diff suppressed because it is too large
@@ -0,0 +1,63 @@
from enum import Enum
from typing import Dict, List, Optional, Sequence

import torch
from fairseq.data import Dictionary


class EncoderLangtok(Enum):
    """
    Prepend to the beginning of source sentence either the
    source or target language token. (src/tgt).
    """

    src = "src"
    tgt = "tgt"


class LangTokSpec(Enum):
    main = "main"
    mono_dae = "mono_dae"


class LangTokStyle(Enum):
    multilingual = "multilingual"
    mbart = "mbart"


@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    # TOKEN_STYLES can't be defined outside this fn since it needs to be
    # TorchScriptable.
    TOKEN_STYLES: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    if spec.endswith("dae"):
        lang = f"{lang}_dae"
    elif spec.endswith("mined"):
        lang = f"{lang}_mined"
    style = TOKEN_STYLES[lang_tok_style]
    return style.format(lang)


def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    for spec in langtoks_specs:
        for language in language_list:
            dictionary.add_symbol(
                get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
            )

    if lang_tok_style == LangTokStyle.mbart.value or (
        extra_data is not None and LangTokSpec.mono_dae.value in extra_data
    ):
        dictionary.add_symbol("<mask>")
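The resulting token formats can be read directly off `TOKEN_STYLES` above:

assert get_lang_tok("en", LangTokStyle.multilingual.value) == "__en__"
assert get_lang_tok("en", LangTokStyle.mbart.value) == "[en]"
assert get_lang_tok(
    "de", LangTokStyle.multilingual.value, spec=LangTokSpec.mono_dae.value
) == "__de_dae__"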
@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import hashlib
import logging
import time
from bisect import bisect_right
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils
from fairseq.distributed import utils as distributed_utils


def get_time_gap(s, e):
    return (
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    ).__str__()


logger = logging.getLogger(__name__)


def default_virtual_size_func(datasets, ratios, max_scale_up=1.5):
    sizes = [len(d) for d in datasets]
    if ratios is None:
        return sum(sizes)
    largest_idx = np.argmax(sizes)
    largest_r = ratios[largest_idx]
    largest_s = sizes[largest_idx]
    # set virtual sizes relative to the largest dataset
    virtual_sizes = [(r / largest_r) * largest_s for r in ratios]
    vsize = sum(virtual_sizes)
    max_size = sum(sizes) * max_scale_up
    return int(vsize if vsize < max_size else max_size)


class CollateFormat(Enum):
    single = 1
    ordered_dict = 2


class SampledMultiDataset(FairseqDataset):
    """Samples from multiple sub-datasets according to given sampling ratios.

    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of probability of each dataset to be sampled
            (default: None, which corresponds to concatenating all datasets together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
            CollateFormat.single (default: CollateFormat.single), where CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets,
            and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
            of sub-datasets.
            Note that not all sub-datasets will be present in a single batch in both formats.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        shared_collater (bool): whether or not all sub-datasets have the same collater.
        shuffle (bool): whether or not to shuffle data (default: True).
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        shared_collater=False,
        shuffle=True,
    ):
        super().__init__()
        self.shared_collater = shared_collater
        self.shuffle = shuffle

        if isinstance(datasets, OrderedDict):
            self.keys = list(datasets.keys())
            datasets = list(datasets.values())
        elif isinstance(datasets, List):
            self.keys = list(range(len(datasets)))
        else:
            raise AssertionError()
        self.datasets = datasets
        self.split = split

        self.eval_key = eval_key
        if self.eval_key is not None:
            self.collate_format = CollateFormat.single
        else:
            self.collate_format = collate_format

        self.seed = seed
        self._cur_epoch = None

        self.cumulated_sizes = None
        # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset
        # namely, data item i is sampled from the kth sub-dataset self.datasets[k]
        # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k]
        self._cur_indices = None

        self._sizes = None
        self.virtual_size_per_dataset = None
        # caching properties
        self._reset_cached_properties()
        self.setup_sampling(sampling_ratios, virtual_size)
        self.set_epoch(epoch)

    def _clean_if_not_none(self, var_list):
        for v in var_list:
            if v is not None:
                del v

    def _reset_cached_properties(self):
        self._clean_if_not_none([self._sizes, self._cur_indices])
        self._sizes = None
        self._cur_indices = None

    def setup_sampling(self, sample_ratios, virtual_size):
        sizes = [len(d) for d in self.datasets]
        if sample_ratios is None:
            # default back to concatenating datasets
            self.sample_ratios = None
            self.virtual_size = sum(sizes)
        else:
            if not isinstance(sample_ratios, np.ndarray):
                sample_ratios = np.array(sample_ratios)
            self.sample_ratios = sample_ratios
            virtual_size = (
                default_virtual_size_func if virtual_size is None else virtual_size
            )
            self.virtual_size = (
                virtual_size(self.datasets, self.sample_ratios)
                if callable(virtual_size)
                else virtual_size
            )

    def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
        if sampling_ratios is not None:
            sampling_ratios = self._sync_sample_ratios(sampling_ratios)
            self.setup_sampling(sampling_ratios, virtual_size)

    def _sync_sample_ratios(self, ratios):
        # in case the ratios are not precisely the same across processes
        # also to ensure every process updates the ratios at the same pace
        ratios = torch.DoubleTensor(ratios)
        if torch.distributed.is_initialized():
            if torch.cuda.is_available():
                distributed_utils.all_reduce(
                    ratios.cuda(), group=distributed_utils.get_data_parallel_group()
                )
            else:
                distributed_utils.all_reduce(
                    ratios, group=distributed_utils.get_data_parallel_group()
                )
        ret = ratios.cpu()
        ret = ret.numpy()
        return ret

    def random_choice_in_dataset(self, rng, dataset, choice_size):
        if hasattr(dataset, "random_choice_in_dataset"):
            return dataset.random_choice_in_dataset(rng, choice_size)
        dataset_size = len(dataset)
        return rng.choice(
            dataset_size, choice_size, replace=(choice_size > dataset_size)
        )

    def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
        def get_counts(sample_ratios):
            counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64)
            diff = virtual_size - counts.sum()
            assert diff >= 0
            # due to round-offs, the size might not match the desired sizes
            if diff > 0:
                dataset_indices = rng.choice(
                    len(sample_ratios), size=diff, p=sample_ratios
                )
                for i in dataset_indices:
                    counts[i] += 1
            return counts

        def get_in_dataset_indices(datasets, sizes, sample_ratios):
            counts = get_counts(sample_ratios)
            # uniformly sample desired counts for each dataset
            # if the desired counts are large, sample with replacement:
            indices = [
                self.random_choice_in_dataset(rng, d, c)
                for c, d in zip(counts, datasets)
            ]
            return indices

        sizes = [len(d) for d in datasets]
        if sample_ratios is None:
            # default back to concatenating datasets
            in_dataset_indices = [list(range(s)) for s in sizes]
            virtual_sizes_per_dataset = sizes
        else:
            ratios = sample_ratios / sample_ratios.sum()
            in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios)
            virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices]
        virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64)
        cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
        assert sum(virtual_sizes_per_dataset) == virtual_size
        assert cumulative_sizes[-1] == virtual_size
        if virtual_size < sum(sizes):
            logger.warning(
                f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})."
                " If virtual size << real data size, there could be data coverage issue."
            )
        in_dataset_indices = np.hstack(in_dataset_indices)
        return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset

    def _get_dataset_and_index(self, index):
        i = bisect_right(self.cumulated_sizes, index)
        return i, self._cur_indices[index]

    def __getitem__(self, index):
        # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]]
        # where k satisfies self.cumulated_sizes[k - 1] <= k < self.cumulated_sizes[k]
        ds_idx, ds_sample_idx = self._get_dataset_and_index(index)
        ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx])
        return ret

    def num_tokens(self, index):
        return self.sizes[index].max()

    def num_tokens_vec(self, indices):
        sizes_vec = self.sizes[np.array(indices)]
        # max across all dimensions but first one
        return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape))))

    def size(self, index):
        return self.sizes[index]

    def __len__(self):
        return self.virtual_size

    def collater(self, samples, **extra_args):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.collate_format == "ordered_dict":
            collect_samples = [[] for _ in range(len(self.datasets))]
            for (i, sample) in samples:
                collect_samples[i].append(sample)
            batch = OrderedDict(
                [
                    (self.keys[i], dataset.collater(collect_samples[i]))
                    for i, (key, dataset) in enumerate(zip(self.keys, self.datasets))
                    if len(collect_samples[i]) > 0
                ]
            )
        elif self.shared_collater:
            batch = self.datasets[0].collater([s for _, s in samples])
        else:
            samples_dict = defaultdict(list)
            pad_to_length = (
                defaultdict(int)
                if "pad_to_length" not in extra_args
                else extra_args["pad_to_length"]
            )
            for ds_idx, s in samples:
                pad_to_length["source"] = max(
                    pad_to_length["source"], s["source"].size(0)
                )
                if s["target"] is not None:
                    pad_to_length["target"] = max(
                        pad_to_length["target"], s["target"].size(0)
                    )
                samples_dict[ds_idx].append(s)
            batches = [
                self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
                for i in range(len(self.datasets))
                if len(samples_dict[i]) > 0
            ]

            def straight_data(tensors):
                batch = torch.cat(tensors, dim=0)
                return batch

            src_lengths = straight_data(
                [b["net_input"]["src_lengths"] for b in batches]
            )
            src_lengths, sort_order = src_lengths.sort(descending=True)

            def straight_order(tensors):
                batch = straight_data(tensors)
                return batch.index_select(0, sort_order)

            batch = {
                "id": straight_order([b["id"] for b in batches]),
                "nsentences": sum(b["nsentences"] for b in batches),
                "ntokens": sum(b["ntokens"] for b in batches),
                "net_input": {
                    "src_tokens": straight_order(
                        [b["net_input"]["src_tokens"] for b in batches]
                    ),
                    "src_lengths": src_lengths,
                },
                "target": straight_order([b["target"] for b in batches])
                if batches[0]["target"] is not None
                else None,
            }
            if "prev_output_tokens" in batches[0]["net_input"]:
                batch["net_input"]["prev_output_tokens"] = straight_order(
                    [b["net_input"]["prev_output_tokens"] for b in batches]
                )
            if "src_lang_id" in batches[0]["net_input"]:
                batch["net_input"]["src_lang_id"] = straight_order(
                    [b["net_input"]["src_lang_id"] for b in batches]
                )
            if "tgt_lang_id" in batches[0]:
                batch["tgt_lang_id"] = straight_order(
                    [b["tgt_lang_id"] for b in batches]
                )
        return batch

    @property
    def sizes(self):
        if self._sizes is not None:
            return self._sizes
        start_time = time.time()
        in_sub_dataset_indices = [
            self._cur_indices[
                0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i]
            ]
            for i in range(len(self.datasets))
        ]
        sub_dataset_sizes = [
            d.sizes[indices]
            for d, indices in zip(self.datasets, in_sub_dataset_indices)
        ]
        self._sizes = np.vstack(sub_dataset_sizes)
        logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
        return self._sizes

    def ordered_indices(self):
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))

        sizes = self.sizes
        tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
        src_sizes = (
            sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
        )

        # sort by target length, then source length
        if tgt_sizes is not None:
            indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
        sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
        return sort_indices

    def prefetch(self, indices):
        prefetch_indices = [[] for _ in range(len(self.datasets))]
        for i in indices:
            ds_idx, ds_sample_idx = self._get_dataset_and_index(i)
            prefetch_indices[ds_idx].append(ds_sample_idx)
        for i in range(len(prefetch_indices)):
            self.datasets[i].prefetch(prefetch_indices[i])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if epoch == self._cur_epoch:
            # re-enter so return
            return
        for d in self.datasets:
            if hasattr(d, "set_epoch"):
                d.set_epoch(epoch)
        self._cur_epoch = epoch
        self._establish_virtual_datasets()

    def _establish_virtual_datasets(self):
        if self.sample_ratios is None and self._cur_indices is not None:
            # not a sampling dataset, no need to resample if indices are already established
            return
        self._reset_cached_properties()

        start_time = time.time()
        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
rng = np.random.RandomState(
|
||||
[
|
||||
int(
|
||||
hashlib.sha1(
|
||||
str(self.__class__.__name__).encode("utf-8")
|
||||
).hexdigest(),
|
||||
16,
|
||||
)
|
||||
% (2**32),
|
||||
self.seed % (2**32), # global seed
|
||||
self._cur_epoch, # epoch index,
|
||||
]
|
||||
)
|
||||
self._clean_if_not_none(
|
||||
[self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes]
|
||||
)
|
||||
self._sizes = None
|
||||
|
||||
indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices(
|
||||
rng, self.datasets, self.sample_ratios, self.virtual_size
|
||||
)
|
||||
self._cur_indices = indices
|
||||
self.cumulated_sizes = cumulated_sizes
|
||||
self.virtual_size_per_dataset = virtual_size_per_dataset
|
||||
|
||||
raw_sizes = [len(d) for d in self.datasets]
|
||||
sampled_sizes = self.virtual_size_per_dataset
|
||||
logger.info(
|
||||
f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
|
||||
f"raw total size: {sum(raw_sizes)}"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
|
||||
f"resampled total size: {sum(sampled_sizes)}"
|
||||
)
|
||||
if self.sample_ratios is not None:
|
||||
logger.info(
|
||||
f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"[{self.split}] A concat dataset")
|
||||
logger.info(
|
||||
f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
|
||||
)
|
||||
|
||||
def filter_indices_by_size(self, indices, max_sizes):
|
||||
"""Filter a list of sample indices. Remove those that are longer
|
||||
than specified in max_sizes.
|
||||
|
||||
Args:
|
||||
indices (np.array): original array of sample indices
|
||||
max_sizes (int or list[int] or tuple[int]): max sample size,
|
||||
can be defined separately for src and tgt (then list or tuple)
|
||||
|
||||
Returns:
|
||||
np.array: filtered sample array
|
||||
list: list of removed indices
|
||||
"""
|
||||
sizes = self.sizes
|
||||
tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
|
||||
src_sizes = (
|
||||
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
|
||||
)
|
||||
|
||||
return data_utils.filter_paired_dataset_indices_by_size(
|
||||
src_sizes, tgt_sizes, indices, max_sizes
|
||||
)
|
||||
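# Editor's sketch, not part of the diff: the ratio arithmetic performed by
# get_virtual_indices() above when sample_ratios is given, on hypothetical
# numbers (the real helper also rounds and pads per-dataset remainders).
import numpy as np

sizes = [1000, 100]                   # real sizes of two sub-datasets
sample_ratios = np.array([0.7, 0.3])  # assumed sampling ratios
virtual_size = 2000

ratios = sample_ratios / sample_ratios.sum()
virtual_sizes_per_dataset = (ratios * virtual_size).astype(np.int64)
cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
print(virtual_sizes_per_dataset, cumulative_sizes)  # [1400  600] [1400 2000]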
@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hashlib
import logging
import math

import numpy as np

from fairseq.data import SampledMultiDataset

from .sampled_multi_dataset import CollateFormat, default_virtual_size_func

logger = logging.getLogger(__name__)


class SampledMultiEpochDataset(SampledMultiDataset):
    """Samples from multiple sub-datasets according to sampling ratios,
    using virtual epoch sizes to speed up dataloading.

    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of probabilities of each dataset to be sampled
            (default: None, which corresponds to concatenating all datasets together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
            CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets,
            and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
            of sub-datasets.
            Note that not all sub-datasets will be present in a single batch in both formats.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        virtual_epoch_size (int): virtual epoch size; the dataset will go through the data by
            this virtual epoch size one by one to speed up data loading, e.g. indexing and filtering
            can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded.
        shared_collater (bool): whether or not all sub-datasets have the same collater.
        shard_epoch (int): the real epoch number for shard selection.
        shuffle (bool): whether or not to shuffle data (default: True).
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        virtual_epoch_size=None,
        shared_collater=False,
        shard_epoch=1,
        shuffle=True,
    ):
        self.virtual_epoch_size = virtual_epoch_size
        self._current_epoch_start_index = None
        self._random_global_indices = None
        self.shard_epoch = shard_epoch if shard_epoch is not None else 1
        self.load_next_shard = None
        self._epoch_sizes = None
        super().__init__(
            datasets=datasets,
            sampling_ratios=sampling_ratios,
            seed=seed,
            epoch=epoch,
            eval_key=eval_key,
            collate_format=collate_format,
            virtual_size=virtual_size,
            split=split,
            shared_collater=shared_collater,
            shuffle=shuffle,
        )

    def _setup(self, epoch):
        self.virtual_epoch_size = (
            self.virtual_epoch_size
            if self.virtual_epoch_size is not None
            else self.virtual_size
        )
        if self.virtual_epoch_size > self.virtual_size:
            logger.warning(
                f"virtual epoch size {self.virtual_epoch_size} "
                f"is greater than virtual dataset size {self.virtual_size}"
            )
            self.virtual_epoch_size = self.virtual_size
        self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size)
        self._current_epoch_start_index = self._get_epoch_start_index(epoch)
        logger.info(
            f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}"
        )

    def _map_epoch_index_to_global(self, index):
        index = self._current_epoch_start_index + index
        # add randomness
        return self._random_global_indices[index]

    @property
    def sizes(self):
        if self._epoch_sizes is not None:
            return self._epoch_sizes
        _sizes = super().sizes
        indices = self._random_global_indices[
            self._current_epoch_start_index : self._current_epoch_start_index
            + len(self)
        ]
        self._epoch_sizes = _sizes[indices]
        # del super()._sizes to save memory
        del self._sizes
        self._sizes = None
        return self._epoch_sizes

    def _get_dataset_and_index(self, index):
        i = self._map_epoch_index_to_global(index)
        return super()._get_dataset_and_index(i)

    def __len__(self):
        return (
            self.virtual_epoch_size
            if self._current_epoch_start_index + self.virtual_epoch_size
            < self.virtual_size
            else self.virtual_size - self._current_epoch_start_index
        )

    def set_epoch(self, epoch):
        if self._current_epoch_start_index is None:
            # initializing epoch indices of a virtual dataset
            self._setup(epoch)
            self._next_virtual_epoch(epoch)
        else:
            # working on already initialized epoch indices
            if epoch == self._cur_epoch:
                # re-enter so return
                return
            self._next_virtual_epoch(epoch)

    def _get_epoch_start_index(self, epoch):
        assert epoch >= 1  # fairseq is using 1-based epoch everywhere
        return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size

    def _next_global_indices(self, epoch):
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2**32),
                self.seed % (2**32),  # global seed
                epoch,  # epoch index
            ]
        )
        del self._random_global_indices
        self._random_global_indices = rng.choice(
            self.virtual_size, self.virtual_size, replace=False
        )
        if self.load_next_shard is None:
            self.load_next_shard = False
        else:
            # increase shard epoch for next loading
            self.shard_epoch += 1
            self.load_next_shard = True
            logger.info(
                "to load next epoch/shard in next load_dataset: "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )

    def _next_virtual_epoch(self, epoch):
        index = self._get_epoch_start_index(epoch)
        if index == 0 or self._random_global_indices is None:
            # need to start from the beginning,
            # so call super().set_epoch(epoch) to establish the global virtual indices
            logger.info(
                "establishing a new set of global virtual indices for "
                f"epoch={epoch}/shard_epoch={self.shard_epoch}"
            )
            super().set_epoch(epoch)
            self._next_global_indices(epoch)
        else:
            self._cur_epoch = epoch

        # reset cached sizes and ordered_indices after moving to a new epoch
        self._clean_if_not_none(
            [
                self._epoch_sizes,
            ]
        )
        self._epoch_sizes = None
        self._current_epoch_start_index = index
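# Editor's sketch, not part of the diff: the epoch-to-offset arithmetic used by
# _get_epoch_start_index() above, assuming a virtual dataset of 10 items split
# into virtual epochs of 4. fairseq epochs are 1-based and cycle through the
# virtual dataset.
import math

virtual_size, virtual_epoch_size = 10, 4
num_virtual_epochs = math.ceil(virtual_size / virtual_epoch_size)  # 3

for epoch in range(1, 7):
    start = ((epoch - 1) % num_virtual_epochs) * virtual_epoch_size
    length = min(virtual_epoch_size, virtual_size - start)
    print(epoch, start, length)  # offsets cycle 0, 4, 8, 0, 4, 8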
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List


logger = logging.getLogger(__name__)


def uniform(dataset_sizes: List[int]):
    return [1.0] * len(dataset_sizes)


def temperature_sampling(dataset_sizes, temp):
    total_size = sum(dataset_sizes)
    return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]


def make_temperature_sampling(temp=1.0):
    def sampling_func(dataset_sizes):
        return temperature_sampling(dataset_sizes, temp)

    return sampling_func


def make_ratio_sampling(ratios):
    def sampling_func(dataset_sizes):
        return ratios

    return sampling_func


class SamplingMethod:
    @staticmethod
    def add_arguments(parser):
        parser.add_argument(
            "--sampling-method",
            choices=[
                "uniform",
                "temperature",
                "concat",
                "RoundRobin",
            ],
            type=str,
            default="concat",
            help="The method to sample data per language pair",
        )
        parser.add_argument(
            "--sampling-temperature",
            default=1.5,
            type=float,
            help="only works with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        self.args = args
        self.task = task

    def is_adaptive(self):
        return False

    def sampling_method_selector(self):
        args = self.args
        logger.info(f"selected sampler: {args.sampling_method}")
        if args.sampling_method == "uniform":
            return uniform
        elif args.sampling_method == "temperature" or self.is_adaptive():
            return make_temperature_sampling(float(args.sampling_temperature))
        else:
            # default to concatenating all datasets together
            return None
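# Editor's sketch, not part of the diff: a worked example of
# temperature_sampling() above. With T=1 the weights match the raw size
# proportions; larger T flattens them toward uniform (the caller renormalizes).
sizes = [900, 100]
for temp in (1.0, 1.5, 5.0):
    total = sum(sizes)
    weights = [(s / total) ** (1.0 / temp) for s in sizes]
    print(temp, [round(w, 3) for w in weights])
# 1.0 -> [0.9, 0.1]; 1.5 -> [0.932, 0.215]; 5.0 -> [0.979, 0.631]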
@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        new_dico = OrderedDict({prefix: dico})
    return new_dico


def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico


class NestedDictionaryDataset(FairseqDataset):
    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    FairseqDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            return (s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
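# Editor's sketch, not part of the diff: a worked example of the key scheme
# used by _flatten()/_unflatten() above. Nested keys are joined with "." and
# list positions become "[i]" segments; the two calls round-trip.
from collections import OrderedDict

nested = OrderedDict([("net_input", OrderedDict([("src_tokens", 1)])), ("target", 2)])
flat = OrderedDict([("net_input.src_tokens", 1), ("target", 2)])
# _flatten(nested) yields `flat`; _unflatten(flat) restores `nested`.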
334
modules/voice_conversion/fairseq/data/noising.py
Normal file
@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from fairseq.data import data_utils


class WordNoising(object):
    """Generate a noisy version of a sentence, without changing words themselves."""

    def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None):
        self.dictionary = dictionary
        self.bpe_end = None
        if bpe_cont_marker:
            self.bpe_end = np.array(
                [
                    not self.dictionary[i].endswith(bpe_cont_marker)
                    for i in range(len(self.dictionary))
                ]
            )
        elif bpe_end_marker:
            self.bpe_end = np.array(
                [
                    self.dictionary[i].endswith(bpe_end_marker)
                    for i in range(len(self.dictionary))
                ]
            )

        self.get_word_idx = (
            self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx
        )

    def noising(self, x, lengths, noising_prob=0.0):
        raise NotImplementedError()

    def _get_bpe_word_idx(self, x):
        """
        Given a list of BPE tokens, for every index in the tokens list,
        return the index of the word grouping that it belongs to.
        For example, for input x corresponding to ["how", "are", "y@@", "ou"],
        return [[0], [1], [2], [2]].
        """
        # x: (T x B)
        bpe_end = self.bpe_end[x]

        if x.size(0) == 1 and x.size(1) == 1:
            # Special case when we only have one word in x. If x = [[N]],
            # bpe_end is a scalar (bool) instead of a 2-dim array of bools,
            # which makes the sum operation below fail.
            return np.array([[0]])

        # do a reduce front sum to generate word ids
        word_idx = bpe_end[::-1].cumsum(0)[::-1]
        word_idx = word_idx.max(0)[None, :] - word_idx
        return word_idx

    def _get_token_idx(self, x):
        """
        This is to extend noising functions to be able to apply to non-bpe
        tokens, e.g. words or characters.
        """
        x = torch.t(x)
        word_idx = np.array([range(len(x_i)) for x_i in x])
        return np.transpose(word_idx)


class WordDropout(WordNoising):
    """Randomly drop input words. If not passing blank_idx (default is None),
    then dropped words will be removed. Otherwise, they will be replaced by the
    blank_idx."""

    def __init__(
        self,
        dictionary,
        default_dropout_prob=0.1,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_dropout_prob = default_dropout_prob

    def noising(self, x, lengths, dropout_prob=None, blank_idx=None):
        if dropout_prob is None:
            dropout_prob = self.default_dropout_prob
        # x: (T x B), lengths: B
        if dropout_prob == 0:
            return x, lengths

        assert 0 < dropout_prob < 1

        # be sure to drop entire words
        word_idx = self.get_word_idx(x)
        sentences = []
        modified_lengths = []
        for i in range(lengths.size(0)):
            # Since dropout probabilities need to apply over non-pad tokens,
            # it is not trivial to generate the keep mask without considering
            # input lengths; otherwise, this could be done outside the loop

            # We want to drop whole words based on word_idx grouping
            num_words = max(word_idx[:, i]) + 1

            # ith example: [x0, x1, ..., eos, pad, ..., pad]
            # We should only generate keep probs for non-EOS tokens. Thus if the
            # input sentence ends in EOS, the last word idx is not included in
            # the dropout mask generation and we append True to always keep EOS.
            # Otherwise, just generate the dropout mask for all word idx
            # positions.
            has_eos = x[lengths[i] - 1, i] == self.dictionary.eos()
            if has_eos:  # has eos?
                keep = np.random.rand(num_words - 1) >= dropout_prob
                keep = np.append(keep, [True])  # keep EOS symbol
            else:
                keep = np.random.rand(num_words) >= dropout_prob

            words = x[: lengths[i], i].tolist()

            # TODO: speed up the following loop
            # drop words from the input according to keep
            new_s = [
                w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words)
            ]
            new_s = [w for w in new_s if w is not None]
            # we need to have at least one word in the sentence (more than the
            # start / end sentence symbols)
            if len(new_s) <= 1:
                # insert at beginning in case the only token left is EOS
                # EOS should be at end of list.
                new_s.insert(0, words[np.random.randint(0, len(words))])
            assert len(new_s) >= 1 and (
                not has_eos  # Either don't have EOS at end or last token is EOS
                or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos())
            ), "New sentence is invalid."
            sentences.append(new_s)
            modified_lengths.append(len(new_s))
        # re-construct input
        modified_lengths = torch.LongTensor(modified_lengths)
        modified_x = torch.LongTensor(
            modified_lengths.max(), modified_lengths.size(0)
        ).fill_(self.dictionary.pad())
        for i in range(modified_lengths.size(0)):
            modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i]))

        return modified_x, modified_lengths


class WordShuffle(WordNoising):
    """Shuffle words by no more than k positions."""

    def __init__(
        self,
        dictionary,
        default_max_shuffle_distance=3,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary, bpe_cont_marker, bpe_end_marker)
        self.default_max_shuffle_distance = default_max_shuffle_distance

    def noising(self, x, lengths, max_shuffle_distance=None):
        if max_shuffle_distance is None:
            max_shuffle_distance = self.default_max_shuffle_distance
        # x: (T x B), lengths: B
        if max_shuffle_distance == 0:
            return x, lengths

        # max_shuffle_distance < 1 will return the same sequence
        assert max_shuffle_distance > 1

        # define noise word scores
        noise = np.random.uniform(
            0,
            max_shuffle_distance,
            size=(x.size(0), x.size(1)),
        )
        noise[0] = -1  # do not move start sentence symbol
        # be sure to shuffle entire words
        word_idx = self.get_word_idx(x)
        x2 = x.clone()
        for i in range(lengths.size(0)):
            length_no_eos = lengths[i]
            if x[lengths[i] - 1, i] == self.dictionary.eos():
                length_no_eos = lengths[i] - 1
            # generate a random permutation
            scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i]
            # ensure no reordering inside a word
            scores += 1e-6 * np.arange(length_no_eos.item())
            permutation = scores.argsort()
            # shuffle words
            x2[:length_no_eos, i].copy_(
                x2[:length_no_eos, i][torch.from_numpy(permutation)]
            )
        return x2, lengths


class UnsupervisedMTNoising(WordNoising):
    """
    Implements the default configuration for noising in UnsupervisedMT
    (github.com/facebookresearch/UnsupervisedMT)
    """

    def __init__(
        self,
        dictionary,
        max_word_shuffle_distance,
        word_dropout_prob,
        word_blanking_prob,
        bpe_cont_marker="@@",
        bpe_end_marker=None,
    ):
        super().__init__(dictionary)
        self.max_word_shuffle_distance = max_word_shuffle_distance
        self.word_dropout_prob = word_dropout_prob
        self.word_blanking_prob = word_blanking_prob

        self.word_dropout = WordDropout(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )
        self.word_shuffle = WordShuffle(
            dictionary=dictionary,
            bpe_cont_marker=bpe_cont_marker,
            bpe_end_marker=bpe_end_marker,
        )

    def noising(self, x, lengths):
        # 1. Word Shuffle
        noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising(
            x=x,
            lengths=lengths,
            max_shuffle_distance=self.max_word_shuffle_distance,
        )
        # 2. Word Dropout
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_dropout_prob,
        )
        # 3. Word Blanking
        noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising(
            x=noisy_src_tokens,
            lengths=noisy_src_lengths,
            dropout_prob=self.word_blanking_prob,
            blank_idx=self.dictionary.unk(),
        )

        return noisy_src_tokens


class NoisingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        src_dataset,
        src_dict,
        seed,
        noiser=None,
        noising_class=UnsupervisedMTNoising,
        **kwargs
    ):
        """
        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
        samples based on the supplied noising configuration.

        Args:
            src_dataset (~torch.utils.data.Dataset): dataset to wrap.
                Used to build self.src_dataset --
                a LanguagePairDataset with the src dataset as the source dataset and
                None as the target dataset. Should NOT have padding so that
                src_lengths are accurately calculated by the language_pair_dataset
                collate function.
                We use language_pair_dataset here to encapsulate the tgt_dataset
                so we can re-use the LanguagePairDataset collater to format the
                batches in the structure that SequenceGenerator expects.
            src_dict (~fairseq.data.Dictionary): source dictionary
            seed (int): seed to use when generating random noise
            noiser (WordNoising): a pre-initialized :class:`WordNoising`
                instance. If this is None, a new instance will be created using
                *noising_class* and *kwargs*.
            noising_class (class, optional): class to use to initialize a
                default :class:`WordNoising` instance.
            kwargs (dict, optional): arguments to initialize the default
                :class:`WordNoising` instance given by *noiser*.
        """
        self.src_dataset = src_dataset
        self.src_dict = src_dict
        self.seed = seed
        self.noiser = (
            noiser
            if noiser is not None
            else noising_class(
                dictionary=src_dict,
                **kwargs,
            )
        )
        self.sizes = src_dataset.sizes

    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
        to create a batch for the noising dataset.
        """
        src_tokens = self.src_dataset[index]
        src_lengths = torch.LongTensor([len(src_tokens)])
        src_tokens = src_tokens.unsqueeze(0)

        # Transpose src tokens to fit expected shape of x in noising function
        # (batch size, sequence length) -> (sequence length, batch size)
        src_tokens_t = torch.t(src_tokens)

        with data_utils.numpy_seed(self.seed + index):
            noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths)

        # Transpose back to expected src_tokens format
        # (sequence length, 1) -> (1, sequence length)
        noisy_src_tokens = torch.t(noisy_src_tokens)
        return noisy_src_tokens[0]

    def __len__(self):
        """
        The length of the noising dataset is the length of src.
        """
        return len(self.src_dataset)

    @property
    def supports_prefetch(self):
        return self.src_dataset.supports_prefetch

    def prefetch(self, indices):
        if self.src_dataset.supports_prefetch:
            self.src_dataset.prefetch(indices)
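# Editor's sketch, not part of the diff: a hedged usage example for the noising
# classes above. `d` and `src_dataset` are assumed to exist (a fairseq
# Dictionary and a dataset of 1-D LongTensors), so the calls are left commented.
# noiser = UnsupervisedMTNoising(
#     dictionary=d,
#     max_word_shuffle_distance=3,
#     word_dropout_prob=0.1,
#     word_blanking_prob=0.1,
# )
# noisy = NoisingDataset(src_dataset, d, seed=42, noiser=noiser)
# noisy[0]  # the same sentence with words shuffled, dropped, and blanked to unk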
17
modules/voice_conversion/fairseq/data/num_samples_dataset.py
Normal file
@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import FairseqDataset


class NumSamplesDataset(FairseqDataset):
    def __getitem__(self, index):
        return 1

    def __len__(self):
        return 0

    def collater(self, samples):
        return sum(samples)
31
modules/voice_conversion/fairseq/data/numel_dataset.py
Normal file
@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class NumelDataset(BaseWrapperDataset):
    def __init__(self, dataset, reduce=False):
        super().__init__(dataset)
        self.reduce = reduce

    def __getitem__(self, index):
        item = self.dataset[index]
        if torch.is_tensor(item):
            return torch.numel(item)
        else:
            return np.size(item)

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if self.reduce:
            return sum(samples)
        else:
            return torch.tensor(samples)
@@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class OffsetTokensDataset(BaseWrapperDataset):
    def __init__(self, dataset, offset):
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, idx):
        return self.dataset[idx] + self.offset
31
modules/voice_conversion/fairseq/data/pad_dataset.py
Normal file
@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data import data_utils

from . import BaseWrapperDataset


class PadDataset(BaseWrapperDataset):
    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad
        self.pad_length = pad_length

    def collater(self, samples):
        return data_utils.collate_tokens(
            samples, self.pad_idx, left_pad=self.left_pad, pad_to_length=self.pad_length
        )


class LeftPadDataset(PadDataset):
    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=True)


class RightPadDataset(PadDataset):
    def __init__(self, dataset, pad_idx):
        super().__init__(dataset, pad_idx, left_pad=False)
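# Editor's sketch, not part of the diff: what RightPadDataset's collater does
# with variable-length items, assuming fairseq is importable. The shorter item
# is right-padded with pad_idx.
import torch
from fairseq.data import data_utils

samples = [torch.LongTensor([5, 6]), torch.LongTensor([7])]
print(data_utils.collate_tokens(samples, pad_idx=1, left_pad=False))
# tensor([[5, 6],
#         [7, 1]])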
197
modules/voice_conversion/fairseq/data/plasma_utils.py
Normal file
@@ -0,0 +1,197 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import hashlib
import json
import subprocess
import tempfile
from typing import Hashable

try:
    import pyarrow.plasma as plasma

    PYARROW_AVAILABLE = True
except ImportError:
    plasma = None
    PYARROW_AVAILABLE = False


class PlasmaArray:
    """
    Wrapper around numpy arrays that automatically moves the data to shared
    memory upon serialization. This is particularly helpful when passing numpy
    arrays through multiprocessing, so that data is not unnecessarily
    duplicated or pickled.
    """

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.disable = array.nbytes < 134217728  # disable for arrays <128MB
        self.object_id = None
        self.path = None

        # variables with underscores shouldn't be pickled
        self._client = None
        self._server = None
        self._server_tmp = None
        self._plasma = None

    @property
    def plasma(self):
        if self._plasma is None and not self.disable:
            self._plasma = plasma
        return self._plasma

    def start_server(self):
        if self.plasma is None or self._server is not None:
            return
        assert self.object_id is None
        assert self.path is None
        self._server_tmp = tempfile.NamedTemporaryFile()
        self.path = self._server_tmp.name
        self._server = subprocess.Popen(
            ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path]
        )

    @property
    def client(self):
        if self._client is None:
            assert self.path is not None
            self._client = self.plasma.connect(self.path, num_retries=200)
        return self._client

    def __getstate__(self):
        """Called on pickle save"""
        if self.plasma is None:
            return self.__dict__
        if self.object_id is None:
            self.start_server()
            self.object_id = self.client.put(self.array)
        state = self.__dict__.copy()
        del state["array"]
        state["_client"] = None
        state["_server"] = None
        state["_server_tmp"] = None
        state["_plasma"] = None
        return state

    def __setstate__(self, state):
        """Called on pickle load"""
        self.__dict__.update(state)
        if self.plasma is None:
            return
        self.array = self.client.get(self.object_id)

    def __del__(self):
        if self._server is not None:
            self._server.kill()
            self._server = None
            self._server_tmp.close()
            self._server_tmp = None


DEFAULT_PLASMA_PATH = "/tmp/plasma"


class PlasmaView:
    """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization,
    PlasmaView writes to shared memory on instantiation."""

    def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None):
        """
        Args:
            array: numpy array to store. This can be read with ``PlasmaView().array``
            split_path: the path whence the data was read, used for hashing
            hash_data: other metadata about the array that can be used to create a unique key.
                As of writing, the 3 callers in ``TokenBlockDataset`` use::

                    hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2)

        """
        assert PYARROW_AVAILABLE
        assert split_path is not None
        if plasma_path is None:
            plasma_path = DEFAULT_PLASMA_PATH

        self.path = plasma_path
        self.split_path = split_path
        self._client = None  # Initialize lazily for pickle. plasma clients should not be deep copied or serialized.
        self._n = None

        self.object_id = self.get_object_id(self.split_path, hash_data)
        try:
            self.client.put(array, object_id=self.object_id)
        except plasma.PlasmaObjectExists:
            pass

    @property
    def client(self):
        if self._client is None:
            self._client = plasma.connect(self.path, num_retries=200)
        return self._client

    @property
    def array(self):
        """Fetch a read only view of an np.array, stored in plasma."""
        ret = self.client.get(self.object_id)
        return ret

    @staticmethod
    def get_object_id(split_path: str, hash_data: Hashable):
        """Returns plasma.ObjectID from hashing split_path and object_num."""
        hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20)
        harg = json.dumps(hash_data).encode("utf-8")
        hash.update(harg)
        return plasma.ObjectID(hash.digest())

    def __getstate__(self):
        """Called on pickle save"""
        self.disconnect()
        state = self.__dict__.copy()
        assert state["_client"] is None
        assert "object_id" in state
        return state

    def __setstate__(self, state):
        """Called on pickle load"""
        self.__dict__.update(state)

    def __del__(self):
        self.disconnect()

    def disconnect(self):
        if self._client is not None:
            self._client.disconnect()
            self._client = None

    def __len__(self):
        """Save reads by caching len"""
        if self._n is None:
            self._n = len(self.array)
        return self._n


GB100 = (1024**3) * 100


class PlasmaStore:
    def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100):
        self.server = self.start(path, nbytes)

    def __del__(self):
        self.server.kill()

    @staticmethod
    def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen:
        if not PYARROW_AVAILABLE:
            raise ImportError("please run pip install pyarrow to use --use_plasma_view")
        # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm
        _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path])
        plasma.connect(path, num_retries=200)  # If we can't connect we fail immediately
        return _server
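# Editor's sketch, not part of the diff: PlasmaArray's pickling trick above.
# Pickling moves the array into the plasma store, so only a small object id
# crosses process boundaries. Requires pyarrow and the plasma_store binary,
# hence left commented.
# import numpy as np, pickle
# arr = PlasmaArray(np.zeros(200_000_000, dtype=np.uint8))  # >128MB, so enabled
# blob = pickle.dumps(arr)       # __getstate__ puts the array into plasma
# restored = pickle.loads(blob)  # __setstate__ fetches it back by object_id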
28
modules/voice_conversion/fairseq/data/prepend_dataset.py
Normal file
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependDataset(BaseWrapperDataset):
    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
        super().__init__(dataset)
        self.prepend_getter = prepend_getter
        self.ensure_first_token = ensure_first_token_is

    def __getitem__(self, idx):
        item = self.dataset[idx]
        is_tuple = isinstance(item, tuple)
        src = item[0] if is_tuple else item

        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
        prepend_idx = self.prepend_getter(self.dataset, idx)
        assert isinstance(prepend_idx, int)
        src[0] = prepend_idx
        item = tuple((src,) + item[1:]) if is_tuple else src
        return item
@@ -0,0 +1,41 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class PrependTokenDataset(BaseWrapperDataset):
    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item.new([self.token]), item])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
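# Editor's sketch, not part of the diff: PrependTokenDataset above shifts every
# item, size, and token count up by one. `base` is a hypothetical wrapped
# dataset of LongTensors, so the calls are left commented.
# wrapped = PrependTokenDataset(base, token=0)  # e.g. 0 as a BOS index
# wrapped[0]      # tensor([0, ...original tokens...])
# wrapped.sizes   # base.sizes + 1, as maintained in __init__ above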
23
modules/voice_conversion/fairseq/data/raw_label_dataset.py
Normal file
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class RawLabelDataset(FairseqDataset):
    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def collater(self, samples):
        return torch.tensor(samples)
36
modules/voice_conversion/fairseq/data/replace_dataset.py
Normal file
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map (Dict[int, int]): map of token to replace -> replacement token
        offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. Should be
            as many as the number of objects returned by the underlying dataset's __getitem__ method.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        srcs = item if is_tuple else [item]

        for offset, src in zip(self.offsets, srcs):
            for k, v in self.replace_map.items():
                src_off = src[offset:] if offset >= 0 else src[:offset]
                src_off.masked_fill_(src_off == k, v)

        item = srcs if is_tuple else srcs[0]
        return item
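# Editor's sketch, not part of the diff: a hedged example for ReplaceDataset
# above, rewriting token id 3 to 7 everywhere except the first position of each
# item. `base` is a hypothetical dataset, so the call is left commented.
# wrapped = ReplaceDataset(base, replace_map={3: 7}, offsets=[1])
# Because masked_fill_ above operates on a slice view, the replacement happens
# in place on the returned tensor; the offsets protect a leading prefix (or a
# trailing suffix when negative).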
139
modules/voice_conversion/fairseq/data/resampling_dataset.py
Normal file
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np

from fairseq.data import BaseWrapperDataset, plasma_utils

logger = logging.getLogger(__name__)


class ResamplingDataset(BaseWrapperDataset):
    """Randomly samples from a given dataset at each epoch.

    Sampling is done with or without replacement, depending on the "replace"
    parameter.

    Optionally, the epoch size can be rescaled. This is potentially desirable
    to increase per-epoch coverage of the base dataset (since sampling with
    replacement means that many items in the dataset will be left out). In the
    case of sampling without replacement, size_ratio should be strictly less
    than 1.

    Args:
        dataset (~torch.utils.data.Dataset): dataset on which to sample.
        weights (List[float]): list of probability weights
            (default: None, which corresponds to uniform sampling).
        replace (bool): sampling mode; True for "with replacement", or False
            for "without replacement" (default: True)
        size_ratio (float): the ratio to subsample to; must be positive
            (default: 1.0).
        batch_by_size (bool): whether or not to batch by sequence length
            (default: True).
        seed (int): RNG seed to use (default: 0).
        epoch (int): starting epoch number (default: 1).
    """

    def __init__(
        self,
        dataset,
        weights=None,
        replace=True,
        size_ratio=1.0,
        batch_by_size=True,
        seed=0,
        epoch=1,
    ):
        super().__init__(dataset)

        if weights is None:
            self.weights = None

        else:
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)

        self.replace = replace

        assert size_ratio > 0.0
        if not self.replace:
            assert size_ratio < 1.0
        self.size_ratio = float(size_ratio)
        self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int)

        self.batch_by_size = batch_by_size
        self.seed = seed

        self._cur_epoch = None
        self._cur_indices = None

        self.set_epoch(epoch)

    def __getitem__(self, index):
        return self.dataset[self._cur_indices.array[index]]

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        if isinstance(self.dataset.sizes, list):
            return [s[self._cur_indices.array] for s in self.dataset.sizes]
        return self.dataset.sizes[self._cur_indices.array]

    def num_tokens(self, index):
        return self.dataset.num_tokens(self._cur_indices.array[index])

    def size(self, index):
        return self.dataset.size(self._cur_indices.array[index])

    def ordered_indices(self):
        if self.batch_by_size:
            order = [
                np.arange(len(self)),
                self.sizes,
            ]  # No need to handle `self.shuffle == True`
            return np.lexsort(order)
        else:
            return np.arange(len(self))

    def prefetch(self, indices):
        self.dataset.prefetch(self._cur_indices.array[indices])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch):
        logger.debug("ResamplingDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.

        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2**32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )
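# Editor's sketch, not part of the diff: the epoch-seeded draw ResamplingDataset
# makes in set_epoch() above, on hypothetical numbers (the dataset itself wraps
# the result in a PlasmaArray).
import numpy as np

seed, epoch, n = 0, 1, 10
rng = np.random.RandomState([42, seed % (2**32), epoch])  # same key layout as above
indices = rng.choice(n, int(np.ceil(n * 0.5)), replace=True)  # size_ratio=0.5
print(indices)  # a reproducible half-size sample; changes when `epoch` changes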
18
modules/voice_conversion/fairseq/data/roll_dataset.py
Normal file
@@ -0,0 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset


class RollDataset(BaseWrapperDataset):
    def __init__(self, dataset, shifts):
        super().__init__(dataset)
        self.shifts = shifts

    def __getitem__(self, index):
        item = self.dataset[index]
        return torch.roll(item, self.shifts)
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import FairseqDataset, LanguagePairDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RoundRobinZipDatasets(FairseqDataset):
|
||||
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
|
||||
|
||||
Shorter datasets are repeated in a round-robin fashion to match the length
|
||||
of the longest one.
|
||||
|
||||
Args:
|
||||
datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
|
||||
:class:`~fairseq.data.FairseqDataset` instances.
|
||||
eval_key (str, optional): a key used at evaluation time that causes
|
||||
this instance to pass-through batches from *datasets[eval_key]*.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets, eval_key=None):
|
||||
super().__init__()
|
||||
if isinstance(datasets, dict):
|
||||
datasets = OrderedDict(datasets)
|
||||
assert isinstance(datasets, OrderedDict)
|
||||
        assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)

        self.datasets = datasets
        self.eval_key = eval_key

        self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
        self.longest_dataset = datasets[self.longest_dataset_key]
        self._ordered_indices: Dict[str, Sequence[int]] = None

    def _map_index(self, key, index):
        assert (
            self._ordered_indices is not None
        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
        o = self._ordered_indices[key]
        return o[index % len(o)]

    def __getitem__(self, index):
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset[self._map_index(key, index)])
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass through batches from a single key
            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
        if self._ordered_indices is not None:
            return len(self._ordered_indices[self.longest_dataset_key])
        return len(self.longest_dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset.collater([sample[key] for sample in samples]))
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
            # at evaluation time it's useful to pass through batches from a single key
            return self.datasets[self.eval_key].collater(samples)

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        # TODO make it configurable whether to use max() or sum() here
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return {
            key: dataset.size(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        }

    def ordered_indices(self):
        """Ordered indices for batching."""
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        return np.arange(len(self))

    def filter_indices_by_size(self, indices, max_positions=None):
        """
        Filter each sub-dataset independently, then update the round robin to work
        on the filtered sub-datasets.
        """

        def _deep_until_language_pair(dataset):
            if isinstance(dataset, LanguagePairDataset):
                return dataset
            if hasattr(dataset, "tgt_dataset"):
                return _deep_until_language_pair(dataset.tgt_dataset)
            if hasattr(dataset, "dataset"):
                return _deep_until_language_pair(dataset.dataset)
            raise Exception(f"Don't know how to unwrap this dataset: {dataset}")

        if not isinstance(max_positions, dict):
            max_positions = {k: max_positions for k in self.datasets.keys()}
        ignored_some = False
        for key, dataset in self.datasets.items():
            dataset = _deep_until_language_pair(dataset)
            self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
                self._ordered_indices[key], max_positions[key]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
                    f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
                )
        # Since we modify _ordered_indices in place, it's no longer possible to
        # return valid ignored indices. Hopefully the debug information logged
        # above is enough to diagnose the problem. Ideally we would receive
        # ignore_invalid_inputs so that we could raise a proper error message.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])
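
For orientation, a hedged sketch of how this wrapper is typically driven; the two sub-datasets ("en-fr", "en-de") are illustrative assumptions, not part of this diff:

# Illustrative only: zip two FairseqDatasets and sample them in lock-step.
from collections import OrderedDict

zipped = RoundRobinZipDatasets(
    OrderedDict([("en-fr", en_fr_pairs), ("en-de", en_de_pairs)])  # assumed datasets
)
zipped.ordered_indices()  # must be called first; _map_index asserts on this
sample = zipped[3]        # OrderedDict with one sample per key; shorter datasets
                          # wrap around via index % len (the "round robin")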
78
modules/voice_conversion/fairseq/data/shorten_dataset.py
Normal file
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from fairseq.data import data_utils

from . import BaseWrapperDataset


class TruncateDataset(BaseWrapperDataset):
    """Truncate a sequence by returning the first truncation_length tokens"""

    def __init__(self, dataset, truncation_length):
        super().__init__(dataset)
        assert truncation_length is not None
        self.truncation_length = truncation_length
        self.dataset = dataset

    def __getitem__(self, index):
        item = self.dataset[index]
        item_len = item.size(0)
        if item_len > self.truncation_length:
            item = item[: self.truncation_length]
        return item

    @property
    def sizes(self):
        return np.minimum(self.dataset.sizes, self.truncation_length)

    def __len__(self):
        return len(self.dataset)
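
# Illustrative usage (not part of the diff); `tokens` stands in for any
# dataset of 1-D token tensors with a `.sizes` array:
#   truncated = TruncateDataset(tokens, truncation_length=512)
#   truncated[0]     # == tokens[0][:512] when the item is longer than 512
#   truncated.sizes  # == np.minimum(tokens.sizes, 512)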

class RandomCropDataset(TruncateDataset):
    """Truncate a sequence by returning a random crop of truncation_length tokens"""

    def __init__(self, dataset, truncation_length, seed=1):
        super().__init__(dataset, truncation_length)
        self.seed = seed
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the crop changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            item_len = item.size(0)
            excess = item_len - self.truncation_length
            if excess > 0:
                start_idx = np.random.randint(0, excess)
                item = item[start_idx : start_idx + self.truncation_length]
            return item


def maybe_shorten_dataset(
    dataset,
    split,
    shorten_data_split_list,
    shorten_method,
    tokens_per_sample,
    seed,
):
    truncate_split = (
        split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0
    )
    if shorten_method == "truncate" and truncate_split:
        dataset = TruncateDataset(dataset, tokens_per_sample)
    elif shorten_method == "random_crop" and truncate_split:
        dataset = RandomCropDataset(dataset, tokens_per_sample, seed)
    return dataset
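
A hedged sketch of how maybe_shorten_dataset picks a wrapper; `ds` and the argument values are illustrative assumptions:

# Assumed setup: `ds` is any FairseqDataset of token tensors.
ds = maybe_shorten_dataset(
    ds,
    split="train",
    shorten_data_split_list="",    # empty string means: apply to every split
    shorten_method="random_crop",  # or "truncate"
    tokens_per_sample=512,
    seed=1,
)
# With "random_crop", the crop is re-drawn each epoch but deterministically
# seeded from (seed, epoch, index), which is why epoch iterators can be
# reused (can_reuse_epoch_itr_across_epochs is True).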
21
modules/voice_conversion/fairseq/data/sort_dataset.py
Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class SortDataset(BaseWrapperDataset):
    def __init__(self, dataset, sort_order):
        super().__init__(dataset)
        if not isinstance(sort_order, (list, tuple)):
            sort_order = [sort_order]
        self.sort_order = sort_order

        assert all(len(so) == len(dataset) for so in sort_order)

    def ordered_indices(self):
        return np.lexsort(self.sort_order)
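
A note on the np.lexsort convention this relies on, with an illustrative sketch: the last array in sort_order is the primary sort key, and earlier arrays break ties.

import numpy as np

sizes = np.array([3, 1, 3, 2])
shuffle = np.random.permutation(len(sizes))   # random tie-breaker
order = np.lexsort([shuffle, sizes])          # primary key: sizes
# `order` lists indices by ascending size, with equal sizes shuffled,
# e.g. [1, 3, 0, 2] or [1, 3, 2, 0] depending on the permutation.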
293
modules/voice_conversion/fairseq/data/span_mask_tokens_dataset.py
Normal file
@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import Dictionary, FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "target_lengths": torch.LongTensor([len(t) for t in target]),
        "nsentences": len(samples),  # number of examples in the batch
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch

class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for T5-style datasets.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

            return {
                "id": index,
                "source": torch.from_numpy(source),
                "target": torch.from_numpy(target),
            }

    def random_spans_noise_mask(self, length):
        """
        This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Returns a noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
            num_noise_tokens = round(length * noise_density)
            num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count the length of each segment (segment_id is sorted by construction)
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]
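
    # Worked example (illustrative): with noise_density=0.3 and
    # mean_noise_span_length=1.5 on a length-10 sequence,
    #   num_noise_tokens = round(10 * 0.3) = 3
    #   num_noise_spans  = round(3 / 1.5)  = 2
    # and one possible draw partitions 7 non-noise tokens into [3, 4] and
    # 3 noise tokens into [1, 2], interleaving to lengths [3, 1, 4, 2],
    # i.e. the mask [F, F, F, T, F, F, F, F, T, T].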

    def create_sentinel_ids(self, mask_indices):
        """
        Create sentinel ids for the indices that should be masked.
        The start index of each mask span is replaced by a sentinel id, in
        increasing order. Consecutive mask indices to be deleted are replaced
        with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # make sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids
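
    # Example (illustrative): with vocabulary size V and
    #   mask_indices = [0, 0, 1, 1, 0, 1]
    # the spans start at positions 2 and 5, so this returns
    #   [0, 0, V-1, -1, 0, V-2]
    # i.e. each span start gets a unique sentinel id counting down from the
    # end of the vocabulary, and span continuations become -1 (dropped later).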

    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Put the sentinel mask on `input_ids` and fuse consecutive mask tokens
        into a single mask token by deletion.
        This will reduce the sequence length from `expanded_inputs_length` to
        `input_length`.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # input_ids tokens and sentinel tokens are >= 0; tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]
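
    # Example (illustrative), continuing the values above with
    #   input_ids = [10, 11, 12, 13, 14, 15]:
    #   np.where  -> [10, 11, V-1, -1, 14, V-2]
    #   keep >= 0 -> [10, 11, V-1, 14, V-2]
    # so each noise span collapses to a single sentinel token.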

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        # delegate prefetching to the wrapped dataset
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)
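
For context, a hedged usage sketch; `token_blocks` and `vocab` are assumed to exist and are not part of this diff:

# Illustrative only: wrap a token dataset with T5-style span masking.
masked = SpanMaskedTokensDataset(
    dataset=token_blocks,        # assumed: yields 1-D LongTensors ending in EOS
    vocab=vocab,                 # a fairseq Dictionary
    noise_density=0.15,          # mask ~15% of the tokens
    mean_noise_span_length=3.0,  # average length of a masked span
    shuffle=True,
    seed=1,
)
sample = masked[0]  # {"id": 0, "source": <tensor>, "target": <tensor>}
batch = masked.collater([masked[i] for i in range(4)])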
Some files were not shown because too many files have changed in this diff.