mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 11:21:28 +00:00
Add a monkey-patched fairseq package so that it runs on Python 3.11 (at least the parts needed for our use of RVC)
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
|
||||
class EncoderLangtok(Enum):
    """Choice of language token to prepend to the source sentence.

    The token placed at the beginning of the source sentence is either the
    source-language token or the target-language token.
    """

    # Prepend the source language's token.
    src = "src"
    # Prepend the target language's token.
    tgt = "tgt"
|
||||
|
||||
|
||||
class LangTokSpec(Enum):
    """Names of the language-token specs understood by this module.

    The string values are what callers pass as ``spec`` to ``get_lang_tok``.
    """

    # Default spec: the language name is used unchanged.
    main = "main"
    # Ends in "dae", so ``get_lang_tok`` appends ``_dae`` to the language.
    # NOTE(review): presumably "monolingual denoising auto-encoder" -- confirm.
    mono_dae = "mono_dae"
|
||||
|
||||
|
||||
class LangTokStyle(Enum):
    """Available surface formats for a language token.

    ``get_lang_tok`` maps each value to a template string.
    """

    # Rendered as ``__{lang}__``.
    multilingual = "multilingual"
    # Rendered as ``[{lang}]``.
    mbart = "mbart"
|
||||
|
||||
|
||||
@torch.jit.export
def get_lang_tok(
    lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value
) -> str:
    """Build the language token string for ``lang``.

    Args:
        lang: Language identifier to embed in the token.
        lang_tok_style: A ``LangTokStyle`` value selecting the template
            (``"[{}]"`` for mbart, ``"__{}__"`` for multilingual).
        spec: Token spec name. A spec ending in ``"dae"`` adds a ``_dae``
            suffix to the language, one ending in ``"mined"`` adds
            ``_mined``; any other spec leaves the language unchanged.

    Returns:
        The template for ``lang_tok_style`` filled in with the (possibly
        suffixed) language.

    Raises:
        KeyError: If ``lang_tok_style`` is not a known style value.
    """
    # The template table has to live inside the function (not at module
    # level) so that the function remains TorchScriptable.
    templates: Dict[str, str] = {
        LangTokStyle.mbart.value: "[{}]",
        LangTokStyle.multilingual.value: "__{}__",
    }

    tagged_lang = lang
    if spec.endswith("dae"):
        tagged_lang = lang + "_dae"
    elif spec.endswith("mined"):
        tagged_lang = lang + "_mined"

    template = templates[lang_tok_style]
    return template.format(tagged_lang)
|
||||
|
||||
|
||||
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value,),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Add language tokens (and possibly ``<mask>``) to ``dictionary`` in place.

    Args:
        dictionary: Dictionary that receives the new symbols via
            ``add_symbol``.
        language_list: Languages to create tokens for.
        lang_tok_style: ``LangTokStyle`` value forwarded to ``get_lang_tok``.
        langtoks_specs: Specs to generate tokens for; every language is
            added once per spec, spec-major (all languages of the first
            spec, then all languages of the next one, ...).
        extra_data: Optional mapping; when it contains the ``mono_dae`` spec
            name as a key, ``<mask>`` is added as well.
    """
    # Lazily produced so that token creation and insertion stay interleaved.
    lang_tokens = (
        get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
        for spec in langtoks_specs
        for language in language_list
    )
    for lang_token in lang_tokens:
        dictionary.add_symbol(lang_token)

    # "<mask>" is needed for the mbart style, or whenever mono_dae data is
    # present in extra_data.
    needs_mask = lang_tok_style == LangTokStyle.mbart.value
    if not needs_mask and extra_data is not None:
        needs_mask = LangTokSpec.mono_dae.value in extra_data
    if needs_mask:
        dictionary.add_symbol("<mask>")
|
||||
Reference in New Issue
Block a user