Add TTS module

This commit is contained in:
SillyLossy
2023-03-04 16:55:32 +02:00
parent 8d7ecbc40d
commit ada45f2c70
4 changed files with 54 additions and 21 deletions

View File

@@ -3,9 +3,12 @@ flask-cloudflared
flask-cors
markdown
Pillow
--extra-index-url https://download.pytorch.org/whl/cu116
--extra-index-url https://download.pytorch.org/whl/cu117
torch >= 1.9, < 1.13
torchvision >= 0.9, < 0.13
torchaudio >= 0.9, < 0.13
numpy
accelerate
git+https://github.com/huggingface/transformers
git+https://github.com/huggingface/diffusers
git+https://github.com/huggingface/diffusers
git+https://github.com/coqui-ai/TTS

View File

@@ -1,3 +1,4 @@
from functools import wraps
from flask import Flask, jsonify, request, render_template_string, abort, send_from_directory
from flask_cors import CORS
import markdown
@@ -8,6 +9,8 @@ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
import unicodedata
import torch
import time
import numpy as np
from scipy.io import wavfile
from glob import glob
import json
import os
@@ -28,7 +31,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd', 'tts']
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
'repetition_penalty': 1.0,
@@ -127,6 +130,11 @@ if 'sd' in modules:
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
# Optional TTS module: the TTS package is imported and the YourTTS
# multilingual model is constructed only when 'tts' is among the enabled
# modules, so other configurations do not pay for loading it.
if 'tts' in modules:
from TTS.api import TTS
tts_model = TTS('tts_models/multilingual/multi-dataset/your_tts')
# Sample rate (Hz) handed to wavfile.write by generate_audio.
# NOTE(review): presumably matches the model's output rate - confirm.
sample_rate = 16000
prompt_prefix = "best quality, absurdres, "
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
@@ -152,6 +160,17 @@ app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
extensions = []
def require_module(name):
    """Build a view decorator that gates an endpoint on module `name`.

    The decorated view runs only when `name` is present in the global
    `modules` collection; otherwise the request is aborted with HTTP 403
    and the message 'Module is disabled by config'.
    """
    def decorator(view):
        @wraps(view)
        def guarded(*view_args, **view_kwargs):
            if name in modules:
                return view(*view_args, **view_kwargs)
            # abort() raises, so nothing is returned on this path.
            abort(403, 'Module is disabled by config')
        return guarded
    return decorator
def load_extensions():
for match in glob("./extensions/*/"):
manifest_path = os.path.join(match, 'manifest.json')
@@ -238,6 +257,13 @@ def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image:
return image
def generate_audio(text: str, voice: str) -> str:
    """Synthesize `text` in the given voice and write it to a WAV file.

    Args:
        text: The text to speak.
        voice: Name of the reference voice; the speaker sample is read from
            ``tts_voices/<voice>.wav``.

    Returns:
        Path of the generated file, ``tts_output/<voice>_<timestamp_ns>.wav``.
    """
    # `voice` originates from the HTTP request body. Keep only the final path
    # component so a value such as '../../x' cannot make the model read, or
    # this function write, outside tts_voices/ and tts_output/.
    voice = os.path.basename(voice)
    audio = tts_model.tts(text=text, speaker_wav=f'tts_voices/{voice}.wav', language='en')
    # wavfile.write does not create missing directories; make sure the output
    # directory exists so the first request on a fresh checkout succeeds.
    os.makedirs('tts_output', exist_ok=True)
    filename = f'tts_output/{voice}_{time.time_ns()}.wav'
    wavfile.write(filename, sample_rate, np.array(audio))
    return filename
def image_to_base64(image: Image):
buffered = BytesIO()
image.save(buffered, format="JPEG")
@@ -297,10 +323,8 @@ def get_asset(name: str, asset: str):
@app.route('/api/caption', methods=['POST'])
@require_module('caption')
def api_caption():
if 'caption' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'image' not in data or not isinstance(data['image'], str):
@@ -312,10 +336,8 @@ def api_caption():
@app.route('/api/summarize', methods=['POST'])
@require_module('summarize')
def api_summarize():
if 'summarize' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'text' not in data or not isinstance(data['text'], str):
@@ -331,10 +353,8 @@ def api_summarize():
@app.route('/api/classify', methods=['POST'])
@require_module('classify')
def api_classify():
if 'classify' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'text' not in data or not isinstance(data['text'], str):
@@ -345,10 +365,8 @@ def api_classify():
@app.route('/api/keywords', methods=['POST'])
@require_module('keywords')
def api_keywords():
if 'keywords' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'text' not in data or not isinstance(data['text'], str):
@@ -359,10 +377,8 @@ def api_keywords():
@app.route('/api/prompt', methods=['POST'])
@require_module('prompt')
def api_prompt():
if 'prompt' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'text' not in data or not isinstance(data['text'], str):
@@ -378,10 +394,8 @@ def api_prompt():
@app.route('/api/image', methods=['POST'])
@require_module('sd')
def api_image():
if 'sd' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if 'prompt' not in data or not isinstance(data['prompt'], str):
@@ -392,6 +406,22 @@ def api_image():
return jsonify({'image': base64image})
@app.route('/api/tts', methods=['POST'])
@require_module('tts')
def api_tts():
    """Synthesize speech for the posted text and return it base64-encoded.

    Expects a JSON object with string fields "text" and "voice". Responds
    with ``{"audio": <base64 string of the WAV file>}``; aborts with HTTP 400
    when either field is missing or is not a string.
    """
    data = request.get_json()

    # Reject non-object JSON bodies (null, list, string) with the same 400
    # instead of failing on the key lookups below.
    if not isinstance(data, dict) or not isinstance(data.get('text'), str):
        abort(400, '"text" is required')

    if not isinstance(data.get('voice'), str):
        abort(400, '"voice" is required')

    filename = generate_audio(data['text'], data['voice'])
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, 'rb') as audio_file:
        # b64encode returns bytes, which jsonify cannot serialize; decode the
        # (pure ASCII) result to str before building the response.
        base64audio = base64.b64encode(audio_file.read()).decode('ascii')
    return jsonify({'audio': base64audio})
if args.share:
from flask_cloudflared import _run_cloudflared
metrics_port = randint(8100, 9000)

View File

View File