mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-24 08:28:59 +00:00
Add TTS module
This commit is contained in:
@@ -3,9 +3,12 @@ flask-cloudflared
|
||||
flask-cors
|
||||
markdown
|
||||
Pillow
|
||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117
|
||||
torch >= 1.9, < 1.13
|
||||
torchvision >= 0.9, < 0.13
|
||||
torchaudio >= 0.9, < 0.13
|
||||
numpy
|
||||
accelerate
|
||||
git+https://github.com/huggingface/transformers
|
||||
git+https://github.com/huggingface/diffusers
|
||||
git+https://github.com/huggingface/diffusers
|
||||
git+https://github.com/coqui-ai/TTS
|
||||
68
server.py
68
server.py
@@ -1,3 +1,4 @@
|
||||
from functools import wraps
|
||||
from flask import Flask, jsonify, request, render_template_string, abort, send_from_directory
|
||||
from flask_cors import CORS
|
||||
import markdown
|
||||
@@ -8,6 +9,8 @@ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
from glob import glob
|
||||
import json
|
||||
import os
|
||||
@@ -28,7 +31,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
|
||||
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd', 'tts']
|
||||
DEFAULT_SUMMARIZE_PARAMS = {
|
||||
'temperature': 1.0,
|
||||
'repetition_penalty': 1.0,
|
||||
@@ -127,6 +130,11 @@ if 'sd' in modules:
|
||||
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
|
||||
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
|
||||
if 'tts' in modules:
|
||||
from TTS.api import TTS
|
||||
tts_model = TTS('tts_models/multilingual/multi-dataset/your_tts')
|
||||
sample_rate = 16000
|
||||
|
||||
prompt_prefix = "best quality, absurdres, "
|
||||
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
error hands, bad hands, error fingers, bad fingers, missing fingers
|
||||
@@ -152,6 +160,17 @@ app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
|
||||
extensions = []
|
||||
|
||||
|
||||
def require_module(name):
|
||||
def wrapper(fn):
|
||||
@wraps(fn)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if name not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
return fn(*args, **kwargs)
|
||||
return decorated_view
|
||||
return wrapper
|
||||
|
||||
|
||||
def load_extensions():
|
||||
for match in glob("./extensions/*/"):
|
||||
manifest_path = os.path.join(match, 'manifest.json')
|
||||
@@ -238,6 +257,13 @@ def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image:
|
||||
return image
|
||||
|
||||
|
||||
def generate_audio(text: str, voice: str):
|
||||
audio = tts_model.tts(text=text, speaker_wav=f'tts_voices/{voice}.wav', language='en')
|
||||
filename = f'tts_output/{voice}_{time.time_ns()}.wav';
|
||||
wavfile.write(filename, sample_rate, np.array(audio))
|
||||
return filename
|
||||
|
||||
|
||||
def image_to_base64(image: Image):
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
@@ -297,10 +323,8 @@ def get_asset(name: str, asset: str):
|
||||
|
||||
|
||||
@app.route('/api/caption', methods=['POST'])
|
||||
@require_module('caption')
|
||||
def api_caption():
|
||||
if 'caption' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'image' not in data or not isinstance(data['image'], str):
|
||||
@@ -312,10 +336,8 @@ def api_caption():
|
||||
|
||||
|
||||
@app.route('/api/summarize', methods=['POST'])
|
||||
@require_module('summarize')
|
||||
def api_summarize():
|
||||
if 'summarize' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
@@ -331,10 +353,8 @@ def api_summarize():
|
||||
|
||||
|
||||
@app.route('/api/classify', methods=['POST'])
|
||||
@require_module('classify')
|
||||
def api_classify():
|
||||
if 'classify' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
@@ -345,10 +365,8 @@ def api_classify():
|
||||
|
||||
|
||||
@app.route('/api/keywords', methods=['POST'])
|
||||
@require_module('keywords')
|
||||
def api_keywords():
|
||||
if 'keywords' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
@@ -359,10 +377,8 @@ def api_keywords():
|
||||
|
||||
|
||||
@app.route('/api/prompt', methods=['POST'])
|
||||
@require_module('prompt')
|
||||
def api_prompt():
|
||||
if 'prompt' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
@@ -378,10 +394,8 @@ def api_prompt():
|
||||
|
||||
|
||||
@app.route('/api/image', methods=['POST'])
|
||||
@require_module('sd')
|
||||
def api_image():
|
||||
if 'sd' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if 'prompt' not in data or not isinstance(data['prompt'], str):
|
||||
@@ -392,6 +406,22 @@ def api_image():
|
||||
return jsonify({'image': base64image})
|
||||
|
||||
|
||||
@app.route('/api/tts', methods=['POST'])
|
||||
@require_module('tts')
|
||||
def api_tts():
|
||||
data = request.get_json()
|
||||
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
if 'voice' not in data or not isinstance(data['voice'], str):
|
||||
abort(400, '"voice" is required')
|
||||
|
||||
filename = generate_audio(data['text'], data['voice'])
|
||||
base64audio = base64.b64encode(open(filename, "rb").read())
|
||||
return jsonify({'audio': base64audio})
|
||||
|
||||
|
||||
if args.share:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
metrics_port = randint(8100, 9000)
|
||||
|
||||
0
tts_output/tts_outputs_will_be_here.txt
Normal file
0
tts_output/tts_outputs_will_be_here.txt
Normal file
0
tts_voices/put_wav_files_here.txt
Normal file
0
tts_voices/put_wav_files_here.txt
Normal file
Reference in New Issue
Block a user