mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-29 02:41:21 +00:00
Coqui TTS Cpu/Gpu Flag Fix
This commit is contained in:
@@ -279,6 +279,8 @@ if "coqui-tts" in modules:
|
||||
print("Initializing Coqui TTS client in " + mode + " mode")
|
||||
import tts_coqui as coqui
|
||||
from tts_coqui import *
|
||||
if mode == "GPU":
|
||||
coqui.setGPU(True)
|
||||
if args.coqui_model is not None:
|
||||
coqui.coqui_modeldownload(args.coqui_model)
|
||||
|
||||
|
||||
19
tts_coqui.py
19
tts_coqui.py
@@ -24,6 +24,12 @@ multspeak = "None"
|
||||
loadedModel = "None"
|
||||
spkdirectory = ""
|
||||
multspeakjson = ""
|
||||
_gpu = False
|
||||
|
||||
def setGPU(flag):
|
||||
global _gpu
|
||||
_gpu = flag
|
||||
return
|
||||
|
||||
def model_type(_config_path):
|
||||
try:
|
||||
@@ -51,9 +57,11 @@ def load_model(_model, _gpu, _progress):
|
||||
global loadedModel
|
||||
global multlang
|
||||
global multspeak
|
||||
|
||||
|
||||
status = None
|
||||
|
||||
print("GPU is set to: ", _gpu)
|
||||
|
||||
_model_directory, _file = os.path.split(_model)
|
||||
|
||||
if _model_directory == "": #make it assign vars correctly if no filename provioded
|
||||
@@ -112,6 +120,7 @@ def load_model(_model, _gpu, _progress):
|
||||
if model_type(_config_path) not in _loadertypes:
|
||||
try:
|
||||
print("Loading ", model_type(_config_path))
|
||||
print("Load Line:", _model_path, _progress, _gpu)
|
||||
tts = TTS(model_path=_model_path, config_path=_config_path, progress_bar=_progress, gpu=_gpu)
|
||||
status = "Loaded"
|
||||
loadedModel = _model
|
||||
@@ -285,9 +294,10 @@ def get_coqui_download_models(): #Avail voices list
|
||||
return json_data
|
||||
|
||||
def coqui_modeldownload(_modeldownload): #Avail voices function
|
||||
global _gpu
|
||||
print(_modeldownload)
|
||||
try:
|
||||
tts = TTS(model_name=_modeldownload, progress_bar=True, gpu=False)
|
||||
tts = TTS(model_name=_modeldownload, progress_bar=True, gpu=_gpu)
|
||||
status = "True"
|
||||
except:
|
||||
status = "False"
|
||||
@@ -300,6 +310,7 @@ def coqui_tts(text, speaker_id, mspker_id, style_wav, language_id):
|
||||
global loadedModel
|
||||
global spkdirectory
|
||||
global multspeakjson
|
||||
global _gpu
|
||||
|
||||
try:
|
||||
# Splitting the string to get speaker_id and the rest
|
||||
@@ -343,7 +354,9 @@ def coqui_tts(text, speaker_id, mspker_id, style_wav, language_id):
|
||||
|
||||
if loadedModel != speaker_id:
|
||||
print("MODEL NOT LOADED!!! Loading... ", loadedModel, speaker_id)
|
||||
load_model(speaker_id, True, True) #use GPU and progress bar?
|
||||
print("Loading :", speaker_id, "GPU is: ", _gpu)
|
||||
|
||||
load_model(speaker_id, _gpu, True)
|
||||
|
||||
|
||||
audio_buffer = io.BytesIO()
|
||||
|
||||
Reference in New Issue
Block a user