mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-29 19:01:20 +00:00
Add Silero TTS server
This commit is contained in:
46
server.py
46
server.py
@@ -1,5 +1,5 @@
|
||||
from functools import wraps
|
||||
from flask import Flask, jsonify, request, render_template_string, abort
|
||||
from flask import Flask, jsonify, request, render_template_string, abort, send_from_directory, send_file
|
||||
from flask_cors import CORS
|
||||
import markdown
|
||||
import argparse
|
||||
@@ -9,6 +9,7 @@ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
import os
|
||||
import gc
|
||||
from PIL import Image
|
||||
import base64
|
||||
@@ -32,6 +33,8 @@ DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
||||
DEFAULT_REMOTE_SD_PORT = 7860
|
||||
SILERO_SAMPLES_PATH = 'tts_samples'
|
||||
SILERO_SAMPLE_TEXT = 'The quick brown fox jumps over the lazy dog'
|
||||
#ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
|
||||
DEFAULT_SUMMARIZE_PARAMS = {
|
||||
'temperature': 1.0,
|
||||
@@ -189,6 +192,17 @@ elif 'sd' in modules and sd_use_remote:
|
||||
print(f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}")
|
||||
modules.remove('sd')
|
||||
|
||||
if 'tts' in modules:
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(SILERO_SAMPLES_PATH, exist_ok=True)
    print('Initializing Silero TTS server')
    # Imported here so the dependency is only needed when 'tts' is enabled.
    from silero_api_server import tts
    tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
    # Generate the samples only when the samples directory is empty.
    if not os.listdir(SILERO_SAMPLES_PATH):
        print('Generating Silero TTS samples...')
        tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
        tts_service.generate_samples()
|
||||
|
||||
PROMPT_PREFIX = "best quality, absurdres, "
|
||||
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
error hands, bad hands, error fingers, bad fingers, missing fingers
|
||||
@@ -551,6 +565,36 @@ def api_image_samplers():
|
||||
def get_modules():
|
||||
return jsonify({'modules': modules})
|
||||
|
||||
@app.route("/api/tts/speakers", methods=['GET'])
def tts_speakers():
    """Return the speakers reported by the TTS service as a JSON array.

    Each entry carries the speaker under both "name" and "voice_id", plus a
    "preview_url" built from the request's root URL that points at the
    api/tts/sample/<speaker> route.
    """
    base_url = str(request.url_root)
    voices = []
    for speaker in tts_service.get_speakers():
        voices.append({
            "name": speaker,
            "voice_id": speaker,
            "preview_url": f"{base_url}api/tts/sample/{speaker}",
        })
    return jsonify(voices)
|
||||
|
||||
@app.route("/api/tts/generate", methods=['POST'])
def tts_generate():
    """Generate speech for the posted text and return it as WAV audio.

    Expects a JSON object body with string fields "text" and "speaker".
    Aborts with 400 when the body is not a JSON object or when either field
    is missing or not a string. Aborts with 500 (carrying the speaker name)
    when generation or sending the resulting file raises.
    """
    voice = request.get_json()
    # A syntactically valid JSON body can still be null, a string or a
    # number; reject those up front so the field checks cannot raise
    # TypeError and surface as an unrelated 500.
    if not isinstance(voice, dict):
        abort(400, 'Request body must be a JSON object')
    text = voice.get('text')
    if not isinstance(text, str):
        abort(400, '"text" is required')
    speaker = voice.get('speaker')
    if not isinstance(speaker, str):
        abort(400, '"speaker" is required')
    # Remove asterisks
    text = text.replace("*", "")
    try:
        audio = tts_service.generate(speaker, text)
        return send_file(audio, mimetype='audio/x-wav')
    except Exception as e:
        # Request boundary: log the failure and answer with a 500.
        print(e)
        abort(500, speaker)
|
||||
|
||||
@app.route("/api/tts/sample/<speaker>", methods=['GET'])
def tts_play_sample(speaker: str):
    """Serve the file "<speaker>.wav" from the SILERO_SAMPLES_PATH directory."""
    sample_filename = f"{speaker}.wav"
    return send_from_directory(SILERO_SAMPLES_PATH, sample_filename)
|
||||
|
||||
if args.share:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
|
||||
Reference in New Issue
Block a user