# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import os
import sys
from typing import Dict, List, Optional, Tuple

import numpy as np

from dataclasses import dataclass, field
from fairseq.data import Dictionary, HubertDataset
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
from omegaconf import MISSING

logger = logging.getLogger(__name__)


class LabelEncoder(object):
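    """Maps a raw label line to a tensor of token indices.

    Added note (not in the original source): Dictionary.encode_line
    tokenizes `label` and returns the corresponding index tensor; no EOS
    is appended and unknown symbols are not added to the dictionary.
    """
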
    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )


@dataclass
class HubertPretrainingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    fine_tuning: bool = field(
        default=False, metadata={"help": "set to true if fine-tuning Hubert"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load; frame-level labels for"
                " pre-training, and sequence-level labels for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={"help": "if set, looks for labels in this directory instead"},
    )
    label_rate: float = field(
        default=-1.0,
        metadata={"help": "label frame rate; -1.0 for sequence labels"},
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate; audio files will be up/down-"
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude samples longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to crop to for batching"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to crop to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )
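
# Added note (not in the original source): `labels` lists file
# extensions rather than label types; e.g. labels=["km"] makes the task
# look for dict.km.txt plus {split}.km in label_dir (or in `data` when
# label_dir is unset). `label_rate` is the number of label frames per
# second of audio, with -1.0 meaning one sequence-level label per
# utterance.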


@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig)
class HubertPretrainingTask(FairseqTask):
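    """Fairseq task wrapping HuBERT pre-training and fine-tuning data.

    Added summary (not in the original source): pre-training pairs audio
    manifests with frame-level pseudo-labels (e.g. k-means units) and
    exposes one dictionary per label stream; fine-tuning uses a single
    sequence-level target dictionary.
    """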

    cfg: HubertPretrainingConfig

    def __init__(
        self,
        cfg: HubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"HubertPretrainingTask Config {cfg}")

        self.cfg = cfg
        self.fine_tuning = cfg.fine_tuning

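        # Added note: the task state lazily builds either a single target
        # dictionary (fine-tuning) or one dictionary per label stream
        # (pre-training); both paths call load_dictionaries below.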
        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        return None

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        return self.state.target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        return self.state.dictionaries

    @classmethod
    def setup_task(
        cls, cfg: HubertPretrainingConfig, **kwargs
    ) -> "HubertPretrainingTask":
        return cls(cfg)

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def get_label_dir(self) -> str:
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
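        """Build a HubertDataset for `split`.

        Added note (not in the original source): `{data}/{split}.tsv` is
        a wav2vec-style manifest (first line: audio root directory, then
        one `relative_path<TAB>num_samples` row per utterance), paired
        with one label file `{label_dir}/{split}.{label}` per entry in
        cfg.labels.
        """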
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{label}" for label in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.datasets[split] = HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

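    # Added note: both max_positions and filter_indices_by_size are
    # deliberate no-ops at the task level, since HubertDataset already
    # filters utterances by the min/max keep sizes when loading the
    # manifest.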
    def filter_indices_by_size(
        self, indices: np.ndarray, *args, **kwargs
    ) -> np.ndarray:
        return indices
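

# End-to-end sketch (added for illustration; the paths and the "km"
# label name are hypothetical):
#
#   cfg = HubertPretrainingConfig(
#       data="/data/manifests",    # contains train.tsv / valid.tsv
#       label_dir="/data/labels",  # contains dict.km.txt, train.km, ...
#       labels=["km"],
#       label_rate=100.0,          # k-means units at 100 frames/sec
#   )
#   task = HubertPretrainingTask.setup_task(cfg)
#   task.load_dataset("train")
#   train_set = task.datasets["train"]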