Add Silero TTS server

This commit is contained in:
SillyLossy
2023-05-12 00:31:50 +03:00
parent 782694afe2
commit 42767bd6e8
4 changed files with 77 additions and 2 deletions

View File

@@ -1,5 +1,5 @@
from functools import wraps
from flask import Flask, jsonify, request, render_template_string, abort
from flask import Flask, jsonify, request, render_template_string, abort, send_from_directory, send_file
from flask_cors import CORS
import markdown
import argparse
@@ -9,6 +9,7 @@ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
import unicodedata
import torch
import time
import os
import gc
from PIL import Image
import base64
@@ -32,6 +33,8 @@ DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
# Directory holding Silero TTS sample WAV files; it is created at startup when
# the 'tts' module is enabled and is served by the /api/tts/sample/<speaker> route.
SILERO_SAMPLES_PATH = 'tts_samples'
# Text handed to the TTS service before it generates the voice samples.
SILERO_SAMPLE_TEXT = 'The quick brown fox jumps over the lazy dog'
#ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
@@ -189,6 +192,17 @@ elif 'sd' in modules and sd_use_remote:
print(f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}")
modules.remove('sd')
if 'tts' in modules:
    # exist_ok=True creates the samples directory when it is missing without
    # the check-then-create race of a separate os.path.exists() test.
    os.makedirs(SILERO_SAMPLES_PATH, exist_ok=True)
    print('Initializing Silero TTS server')
    # Imported here so silero_api_server is only loaded when 'tts' is enabled.
    from silero_api_server import tts
    tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
    # An empty samples directory means no samples exist yet, so generate them
    # once from the default sample text.
    if not os.listdir(SILERO_SAMPLES_PATH):
        print('Generating Silero TTS samples...')
        tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
        tts_service.generate_samples()
PROMPT_PREFIX = "best quality, absurdres, "
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -551,6 +565,36 @@ def api_image_samplers():
def get_modules():
    """Return the current ``modules`` collection as JSON under the "modules" key."""
    payload = {'modules': modules}
    return jsonify(payload)
@app.route("/api/tts/speakers", methods=['GET'])
def tts_speakers():
    """Return a JSON list describing every speaker the TTS service reports.

    Each entry uses the speaker identifier for both "name" and "voice_id" and
    carries a "preview_url" that points at this server's
    /api/tts/sample/<speaker> route.
    """
    speaker_entries = []
    for speaker in tts_service.get_speakers():
        entry = {}
        entry["name"] = speaker
        entry["voice_id"] = speaker
        # The preview link is built on the root URL of the incoming request.
        entry["preview_url"] = f"{str(request.url_root)}api/tts/sample/{speaker}"
        speaker_entries.append(entry)
    return jsonify(speaker_entries)
@app.route("/api/tts/generate", methods=['POST'])
def tts_generate():
    """Generate speech audio for the posted text and return it as a WAV file.

    The request body must be a JSON object with string fields "text" and
    "speaker". Asterisks are stripped from the text before generation.

    Aborts with 400 when the body is not a JSON object or when either field
    is missing or not a string, and with 500 (carrying the speaker name) when
    generating or sending the audio raises.
    """
    voice = request.get_json()
    # A syntactically valid JSON body may still be null, a list, a string or
    # a number; reject those up front so the membership tests and key lookups
    # below cannot raise TypeError and surface as a 500.
    if not isinstance(voice, dict):
        abort(400, 'A JSON object body is required')

    if 'text' not in voice or not isinstance(voice['text'], str):
        abort(400, '"text" is required')

    if 'speaker' not in voice or not isinstance(voice['speaker'], str):
        abort(400, '"speaker" is required')

    # Remove asterisks
    voice['text'] = voice['text'].replace("*", "")

    try:
        audio = tts_service.generate(voice['speaker'], voice['text'])
        return send_file(audio, mimetype='audio/x-wav')
    except Exception as e:
        # Boundary handler: report the failure on the console and answer 500.
        print(e)
        abort(500, voice['speaker'])
@app.route("/api/tts/sample/<speaker>", methods=['GET'])
def tts_play_sample(speaker: str):
    """Serve the file ``<speaker>.wav`` from the Silero samples directory."""
    sample_filename = f"{speaker}.wav"
    return send_from_directory(SILERO_SAMPLES_PATH, sample_filename)
if args.share:
from flask_cloudflared import _run_cloudflared