mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-03 12:51:20 +00:00
Add monkey-patched fairseq package to run on Python 3.11 (what is needed for our use of RVC, at least)
387 modules/voice_conversion/fairseq/data/audio/data_cfg.py Normal file
@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        # fail fast: without PyYAML the yaml.load() call below would hit a NameError
        raise ImportError("Please install PyYAML: pip install PyYAML")
    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config
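
# Illustrative usage (hypothetical path); the wrapper classes below call this
# loader from their constructors:
#   config = get_config_from_yaml(Path("/data/config.yaml"))  # plain dict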


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent
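
    # Illustrative config.yaml layout this wrapper reads (values are examples
    # only; every key is optional and falls back to the defaults in the
    # properties below):
    #
    #   vocab_filename: dict.txt
    #   shuffle: true
    #   sample_rate: 16000
    #   audio_root: /data/audio
    #   vocoder:
    #     type: griffin_lim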

    def _auto_convert_to_abs_path(self, x):
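        # resolve paths given relative to the config's directory, e.g. with
        # root "/data", "spm.model" becomes "/data/spm.model" when only the
        # latter exists; dicts are handled recursively, other values pass through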
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returns
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returns
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for the to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be a lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with a specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to an empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allows the train set
        wildcard `_train`, the evaluation set wildcard `_eval` and the
        general wildcard `*` for matching."""
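        # lookup order, e.g. for a split "train_clean" with is_train=True:
        #   1. the exact split name ("train_clean")
        #   2. the "_train" wildcard ("_eval" when is_train=False)
        #   3. the "*" wildcard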
        from copy import deepcopy

        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
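        # the delta_deltas transform stacks the base features with their first
        # and second temporal derivatives, hence the channel count triples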
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' entries in the
        config file, the last such task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name includes 'target' and whose decoder_type
        is not ctc is used.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx


class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
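            # "decoder_layer" in the config is 1-based; convert to a 0-based index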
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
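            # worked example: loss_weight_max=1.0, loss_weight_min=0.0001 and
            # loss_weight_decay_steps=10000 give a step size of 0.9999 / 10000
            # (about 1e-4), so the weight decays linearly from 1.0 and is
            # clamped at 0.0001 from update 10000 onwards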
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})