mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 11:21:28 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
93
modules/voice_conversion/fairseq/data/audio/__init__.py
Normal file
93
modules/voice_conversion/fairseq/data/audio/__init__.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
import importlib
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AudioTransform(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config_dict(cls, config: Optional[Dict] = None):
|
||||
pass
|
||||
|
||||
|
||||
class CompositeAudioTransform(AudioTransform):
|
||||
def _from_config_dict(
|
||||
cls,
|
||||
transform_type,
|
||||
get_audio_transform,
|
||||
composite_cls,
|
||||
config=None,
|
||||
return_empty=False,
|
||||
):
|
||||
_config = {} if config is None else config
|
||||
_transforms = _config.get(f"{transform_type}_transforms")
|
||||
|
||||
if _transforms is None:
|
||||
if return_empty:
|
||||
_transforms = []
|
||||
else:
|
||||
return None
|
||||
|
||||
transforms = [
|
||||
get_audio_transform(_t).from_config_dict(_config.get(_t))
|
||||
for _t in _transforms
|
||||
]
|
||||
return composite_cls(transforms)
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = [t for t in transforms if t is not None]
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.transforms:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
format_string = (
|
||||
[self.__class__.__name__ + "("]
|
||||
+ [f" {t.__repr__()}" for t in self.transforms]
|
||||
+ [")"]
|
||||
)
|
||||
return "\n".join(format_string)
|
||||
|
||||
|
||||
def register_audio_transform(name, cls_type, registry, class_names):
|
||||
def register_audio_transform_cls(cls):
|
||||
if name in registry:
|
||||
raise ValueError(f"Cannot register duplicate transform ({name})")
|
||||
if not issubclass(cls, cls_type):
|
||||
raise ValueError(
|
||||
f"Transform ({name}: {cls.__name__}) must extend "
|
||||
f"{cls_type.__name__}"
|
||||
)
|
||||
if cls.__name__ in class_names:
|
||||
raise ValueError(
|
||||
f"Cannot register audio transform with duplicate "
|
||||
f"class name ({cls.__name__})"
|
||||
)
|
||||
registry[name] = cls
|
||||
class_names.add(cls.__name__)
|
||||
return cls
|
||||
|
||||
return register_audio_transform_cls
|
||||
|
||||
|
||||
def import_transforms(transforms_dir, transform_type):
|
||||
for file in os.listdir(transforms_dir):
|
||||
path = os.path.join(transforms_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
importlib.import_module(
|
||||
f"fairseq.data.audio.{transform_type}_transforms." + name
|
||||
)
|
||||
|
||||
|
||||
# Utility fn for uniform numbers in transforms
|
||||
def rand_uniform(a, b):
|
||||
return np.random.uniform() * (b - a) + a
|
||||
Reference in New Issue
Block a user