mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-26 09:28:57 +00:00
Merge pull request #133 from Tony-sama/neo
RVC expression-based dynamic voice
This commit is contained in:
44
modules/classify/classify_module.py
Normal file
44
modules/classify/classify_module.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Classify module for SillyTavern Extras
|
||||
|
||||
Authors:
|
||||
- Tony Ribeiro (https://github.com/Tony-sama)
|
||||
|
||||
Provides classification features for text
|
||||
|
||||
References:
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
# Prefix prepended to every debug print emitted by this module.
DEBUG_PREFIX = "<Classify module>"

# Models init
# NOTE(review): device is hard-coded; the commented-out clause suggests it was
# meant to honor an `args.cuda_device` CLI option — confirm against server.py.
cuda_device = "cuda:0"# if not args.cuda_device else args.cuda_device
device_string = cuda_device if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)
# Use half precision only on the CUDA device; CPU inference stays float32.
torch_dtype = torch.float32 if device_string != cuda_device else torch.float16

# Hugging Face text-classification pipeline; populated lazily by
# init_text_emotion_classifier() before classify_text_emotion() is called.
text_emotion_pipe = None
def init_text_emotion_classifier(model_name: str) -> None:
    """Create the module-wide emotion-classification pipeline.

    Loads *model_name* as a Hugging Face "text-classification" pipeline on the
    module's selected device/dtype and stores it in ``text_emotion_pipe`` for
    later use by ``classify_text_emotion``.
    """
    global text_emotion_pipe

    print(DEBUG_PREFIX,"Initializing text classification pipeline with model",model_name)
    # top_k=None makes the pipeline return scores for every label, not just
    # the single best one.
    pipeline_options = {
        "model": model_name,
        "top_k": None,
        "device": device,
        "torch_dtype": torch_dtype,
    }
    text_emotion_pipe = pipeline("text-classification", **pipeline_options)
||||
def classify_text_emotion(text: str) -> list:
    """Classify *text* and return the label/score dicts, highest score first.

    Input longer than the model's maximum sequence length is truncated.
    Assumes ``init_text_emotion_classifier`` has been called already.
    """
    # The model cannot attend past its positional-embedding window, so cap
    # the tokenized input there.
    token_limit = text_emotion_pipe.model.config.max_position_embeddings
    scores = text_emotion_pipe(text, truncation=True, max_length=token_limit)[0]
    ordered = sorted(scores, key=lambda entry: entry["score"], reverse=True)
    return ordered
||||
@@ -22,6 +22,7 @@ from flask import abort, request, send_file, jsonify
|
||||
from TTS.api import TTS
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
from modules.utils import silence_log
|
||||
|
||||
DEBUG_PREFIX = "<Coqui-TTS module>"
|
||||
COQUI_MODELS_PATH = "data/models/coqui/"
|
||||
@@ -71,29 +72,33 @@ def coqui_check_model_state():
|
||||
print(DEBUG_PREFIX,"Search for model", model_id)
|
||||
|
||||
coqui_models_folder = ModelManager().output_prefix # models location
|
||||
installed_models = os.listdir(coqui_models_folder)
|
||||
|
||||
model_folder_exists = False
|
||||
model_folder = None
|
||||
# Check if tts folder exist
|
||||
if os.path.isdir(coqui_models_folder):
|
||||
|
||||
for i in installed_models:
|
||||
if model_id == i.replace("--","/"):
|
||||
model_folder_exists = True
|
||||
model_folder = i
|
||||
print(DEBUG_PREFIX,"Folder found:",model_folder)
|
||||
installed_models = os.listdir(coqui_models_folder)
|
||||
|
||||
# Check failed download
|
||||
if model_folder_exists:
|
||||
content = os.listdir(os.path.join(coqui_models_folder,model_folder))
|
||||
print(DEBUG_PREFIX,"Checking content:",content)
|
||||
for i in content:
|
||||
if i == model_folder+".zip":
|
||||
print("Corrupt installed found, model download must have failed previously")
|
||||
model_state = "corrupted"
|
||||
break
|
||||
model_folder_exists = False
|
||||
model_folder = None
|
||||
|
||||
if model_state != "corrupted":
|
||||
model_state = "installed"
|
||||
for i in installed_models:
|
||||
if model_id == i.replace("--","/",3): # Error with model wrong name
|
||||
model_folder_exists = True
|
||||
model_folder = i
|
||||
print(DEBUG_PREFIX,"Folder found:",model_folder)
|
||||
|
||||
# Check failed download
|
||||
if model_folder_exists:
|
||||
content = os.listdir(os.path.join(coqui_models_folder,model_folder))
|
||||
print(DEBUG_PREFIX,"Checking content:",content)
|
||||
for i in content:
|
||||
if i == model_folder+".zip":
|
||||
print("Corrupt installed found, model download must have failed previously")
|
||||
model_state = "corrupted"
|
||||
break
|
||||
|
||||
if model_state != "corrupted":
|
||||
model_state = "installed"
|
||||
|
||||
response = json.dumps({"model_state":model_state})
|
||||
return response
|
||||
@@ -122,40 +127,43 @@ def coqui_install_model():
|
||||
return json.dumps({"status":"downloading"})
|
||||
|
||||
coqui_models_folder = ModelManager().output_prefix # models location
|
||||
installed_models = os.listdir(coqui_models_folder)
|
||||
model_path = None
|
||||
|
||||
print(DEBUG_PREFIX,"Found",len(installed_models),"models in",coqui_models_folder)
|
||||
# Check if tts folder exist
|
||||
if os.path.isdir(coqui_models_folder):
|
||||
installed_models = os.listdir(coqui_models_folder)
|
||||
model_path = None
|
||||
|
||||
for i in installed_models:
|
||||
if model_id == i.replace("--","/"):
|
||||
model_installed = True
|
||||
model_path = os.path.join(coqui_models_folder,i)
|
||||
print(DEBUG_PREFIX,"Found",len(installed_models),"models in",coqui_models_folder)
|
||||
|
||||
if model_installed:
|
||||
print(DEBUG_PREFIX,"model found:", model_id)
|
||||
else:
|
||||
print(DEBUG_PREFIX,"model not found")
|
||||
for i in installed_models:
|
||||
if model_id == i.replace("--","/"):
|
||||
model_installed = True
|
||||
model_path = os.path.join(coqui_models_folder,i)
|
||||
|
||||
if action == "download":
|
||||
if model_installed:
|
||||
abort(500, DEBUG_PREFIX + "Bad request, model already installed.")
|
||||
print(DEBUG_PREFIX,"model found:", model_id)
|
||||
else:
|
||||
print(DEBUG_PREFIX,"model not found")
|
||||
|
||||
is_downloading = True
|
||||
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
||||
is_downloading = False
|
||||
if action == "download":
|
||||
if model_installed:
|
||||
abort(500, DEBUG_PREFIX + "Bad request, model already installed.")
|
||||
|
||||
if action == "repare":
|
||||
if not model_installed:
|
||||
abort(500, DEBUG_PREFIX + " bad request: requesting repare of model not installed")
|
||||
is_downloading = True
|
||||
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
||||
is_downloading = False
|
||||
|
||||
if action == "repare":
|
||||
if not model_installed:
|
||||
abort(500, DEBUG_PREFIX + " bad request: requesting repare of model not installed")
|
||||
|
||||
|
||||
print(DEBUG_PREFIX,"Deleting corrupted model folder:",model_path)
|
||||
shutil.rmtree(model_path, ignore_errors=True)
|
||||
print(DEBUG_PREFIX,"Deleting corrupted model folder:",model_path)
|
||||
shutil.rmtree(model_path, ignore_errors=True)
|
||||
|
||||
is_downloading = True
|
||||
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
||||
is_downloading = False
|
||||
is_downloading = True
|
||||
TTS(model_name=model_id, progress_bar=True, gpu=gpu_mode)
|
||||
is_downloading = False
|
||||
|
||||
response = json.dumps({"status":"done"})
|
||||
return response
|
||||
|
||||
15
modules/utils.py
Normal file
15
modules/utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
from contextlib import contextmanager
import os
import sys


@contextmanager
def silence_log():
    """Temporarily redirect stdout and stderr to ``os.devnull``.

    Yields the open devnull file object. Both streams are restored on exit
    even if the wrapped code raises.
    """
    # FIX: the original never imported `os`, so `os.devnull` raised NameError
    # on first use.
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    try:
        with open(os.devnull, "w") as new_target:
            sys.stdout = new_target
            # FIX: stderr was saved and restored but never redirected, so
            # stderr output was not actually silenced.
            sys.stderr = new_target
            yield new_target
    finally:
        sys.stdout = old_stdout
        sys.stderr = old_stderr
|
||||
@@ -14,13 +14,14 @@ References:
|
||||
"""
|
||||
from flask import abort, request, send_file, jsonify
|
||||
import json
|
||||
import modules.voice_conversion.rvc.rvc as rvc
|
||||
from scipy.io import wavfile
|
||||
import os
|
||||
import io
|
||||
|
||||
from py7zr import pack_7zarchive, unpack_7zarchive
|
||||
import shutil
|
||||
from py7zr import pack_7zarchive, unpack_7zarchive
|
||||
|
||||
import modules.voice_conversion.rvc.rvc as rvc
|
||||
import modules.classify.classify_module as classify_module
|
||||
|
||||
DEBUG_PREFIX = "<RVC module>"
|
||||
RVC_MODELS_PATH = "data/models/rvc/"
|
||||
@@ -30,13 +31,15 @@ TEMP_FOLDER_PATH = "data/tmp/"
|
||||
|
||||
RVC_INPUT_PATH = "data/tmp/rvc_input.wav"
|
||||
RVC_OUTPUT_PATH ="data/tmp/rvc_output.wav"
|
||||
save_file = False
|
||||
|
||||
save_file = False
|
||||
classification_mode = False
|
||||
|
||||
# register file format at first.
|
||||
shutil.register_archive_format('7zip', pack_7zarchive, description='7zip archive')
|
||||
shutil.register_unpack_format('7zip', ['.7z'], unpack_7zarchive)
|
||||
|
||||
|
||||
def rvc_get_models_list():
|
||||
"""
|
||||
Return the list of RVC model in the expected folder
|
||||
@@ -145,8 +148,10 @@ def rvc_process_audio():
|
||||
filterRadius: int [0,7],
|
||||
rmsMixRate: rmsMixRate,
|
||||
protect: float [0,1]
|
||||
text: string
|
||||
"""
|
||||
global save_file
|
||||
global classification_mode
|
||||
|
||||
try:
|
||||
file = request.files.get('AudioFile')
|
||||
@@ -172,31 +177,66 @@ def rvc_process_audio():
|
||||
folder_path = RVC_MODELS_PATH+parameters["modelName"]+"/"
|
||||
model_path = None
|
||||
index_path = None
|
||||
|
||||
print(DEBUG_PREFIX, "Check for pth file in ", folder_path)
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith(".pth"):
|
||||
print(" > set pth as ",file_name)
|
||||
model_path = folder_path+file_name
|
||||
break
|
||||
|
||||
# HACK: emotion mode EXPERIMENTAL
|
||||
if classification_mode:
|
||||
print(DEBUG_PREFIX,"EXPERIMENT MODE: emotions")
|
||||
|
||||
print("> Searching overide code ($emotion$)")
|
||||
emotion = None
|
||||
for code in ["anger","fear", "joy","love","sadness","surprise"]:
|
||||
if "$"+code+"$" in parameters["text"]:
|
||||
print(" > Overide detected:",code)
|
||||
emotion = code
|
||||
parameters["text"] = parameters["text"].replace("$"+code+"$","")
|
||||
print(parameters["text"])
|
||||
break
|
||||
|
||||
if emotion is None:
|
||||
print("> calling text classification pipeline")
|
||||
emotions_score = classify_module.classify_text_emotion(parameters["text"])
|
||||
|
||||
print(" > ",emotions_score)
|
||||
emotion = emotions_score[0]["label"]
|
||||
print(" > Selected:", emotion)
|
||||
|
||||
model_path = folder_path+emotion+".pth"
|
||||
index_path = folder_path+emotion+".index"
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(" > WARNING emotion model pth not found:",model_path," will try loading default")
|
||||
model_path = None
|
||||
|
||||
if not os.path.exists(index_path):
|
||||
print(" > WARNING emotion model index not found:",index_path)
|
||||
index_path = None
|
||||
|
||||
if model_path is None:
|
||||
abort(500, DEBUG_PREFIX + " No pth file found.")
|
||||
print(DEBUG_PREFIX, "Check for pth file in ", folder_path)
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith(".pth"):
|
||||
print(" > set pth as ",file_name)
|
||||
model_path = folder_path+file_name
|
||||
break
|
||||
|
||||
if model_path is None:
|
||||
abort(500, DEBUG_PREFIX + " No pth file found.")
|
||||
|
||||
print(DEBUG_PREFIX, "Check for index file", folder_path)
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith(".index"):
|
||||
print(" > set index as ",file_name)
|
||||
index_path = folder_path+file_name
|
||||
break
|
||||
|
||||
if index_path is None:
|
||||
index_path = ""
|
||||
print(DEBUG_PREFIX, "no index file found, proceeding without index")
|
||||
|
||||
|
||||
print(DEBUG_PREFIX, "loading", model_path)
|
||||
rvc.load_rvc(model_path)
|
||||
|
||||
print(DEBUG_PREFIX, "Check for index file", folder_path)
|
||||
for file_name in os.listdir(folder_path):
|
||||
if file_name.endswith(".index"):
|
||||
print(" > set index as ",file_name)
|
||||
index_path = folder_path+file_name
|
||||
break
|
||||
|
||||
if index_path is None:
|
||||
index_path = ""
|
||||
print(DEBUG_PREFIX, "no index file found, proceeding without index")
|
||||
|
||||
info, (tgt_sr, wav_opt) = rvc.vc_single(
|
||||
sid=0,
|
||||
input_audio_path=input_audio_path,
|
||||
|
||||
28
server.py
28
server.py
@@ -221,16 +221,6 @@ if "summarize" in modules:
|
||||
summarization_model, torch_dtype=torch_dtype
|
||||
).to(device)
|
||||
|
||||
if "classify" in modules:
|
||||
print("Initializing a sentiment classification pipeline...")
|
||||
classification_pipe = pipeline(
|
||||
"text-classification",
|
||||
model=classification_model,
|
||||
top_k=None,
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
if "sd" in modules and not sd_use_remote:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
@@ -337,6 +327,20 @@ if max_content_length is not None:
|
||||
print("Setting MAX_CONTENT_LENGTH to",max_content_length,"Mb")
|
||||
app.config["MAX_CONTENT_LENGTH"] = int(max_content_length) * 1024 * 1024
|
||||
|
||||
# TODO: Keij, unify main classify and module one
|
||||
if "classify" in modules:
|
||||
print("Initializing a sentiment classification pipeline...")
|
||||
classification_pipe = pipeline(
|
||||
"text-classification",
|
||||
model=classification_model,
|
||||
top_k=None,
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
import modules.classify.classify_module as classify_module
|
||||
classify_module.init_text_emotion_classifier(classification_model)
|
||||
|
||||
if "vosk-stt" in modules:
|
||||
print("Initializing Vosk speech-recognition (from ST request file)")
|
||||
vosk_model_path = (
|
||||
@@ -389,6 +393,10 @@ if "rvc" in modules:
|
||||
|
||||
import modules.voice_conversion.rvc_module as rvc_module
|
||||
rvc_module.save_file = rvc_save_file
|
||||
|
||||
if "classify" in modules:
|
||||
rvc_module.classification_mode = True
|
||||
|
||||
rvc_module.fix_model_install()
|
||||
app.add_url_rule("/api/voice-conversion/rvc/get-models-list", view_func=rvc_module.rvc_get_models_list, methods=["POST"])
|
||||
app.add_url_rule("/api/voice-conversion/rvc/upload-models", view_func=rvc_module.rvc_upload_models, methods=["POST"])
|
||||
|
||||
Reference in a new issue · Block a user