Files
SillyTavern-extras/server.py
2023-02-28 23:08:28 +02:00

95 lines
3.0 KiB
Python

from flask import Flask, jsonify, request, flash, abort
import argparse
from transformers import AutoTokenizer, BartForConditionalGeneration
from transformers import BlipForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import unicodedata
import torch
import json
from PIL import Image
import base64
from io import BytesIO
# Constants
# Model identifiers passed to `from_pretrained` when the matching
# --bart-model / --bert-model / --blip-model flag is not given.
# Alternative BART checkpoint worth trying: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'  # BartForConditionalGeneration
DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'  # AutoModelForSequenceClassification
DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'  # BLIP captioning (caption_image)

# Script arguments
parser = argparse.ArgumentParser(
    description='Web API for transformers models',
    prog='TavernAI Extras',
)
parser.add_argument(
    '--port',
    help="Specify the port on which the application is hosted",
    type=int,
)
parser.add_argument(
    '--listen',
    help="Hosts the app on the local network",
    action='store_true',
)
parser.add_argument(
    '--share',
    help="Shares the app on CloudFlare tunnel",
    action='store_true',
)
# Optional overrides for the DEFAULT_* model identifiers above.
parser.add_argument('--bart-model', help="Load a custom BART model")
parser.add_argument('--bert-model', help="Load a custom BERT model")
parser.add_argument('--blip-model', help="Load a custom BLIP model")

# Parses sys.argv at import time; the script exits with a usage error on bad flags.
args = parser.parse_args()
# Resolve effective settings, falling back to defaults for omitted flags.
# NOTE: the fallbacks test truthiness, so an explicit empty string (or port 0)
# is treated the same as an omitted flag.
port = args.port or 5100

if args.share:
    # Imported lazily so flask_cloudflared is only required when --share is used.
    from flask_cloudflared import _run_cloudflared
    # `_run_cloudflared` is a private flask_cloudflared helper: it starts a
    # cloudflared tunnel forwarding to `port` and returns the tunnel's public
    # URL (private API -- may change between flask_cloudflared versions).
    cloudflare = _run_cloudflared(port)
    # Report the URL; without this, --share users cannot learn the public address.
    print(f'Running on {cloudflare}')

# --listen binds every interface; otherwise only local connections are accepted.
host = '0.0.0.0' if args.listen else 'localhost'

# Model identifiers: CLI override if given, else the built-in default.
bart_model = args.bart_model or DEFAULT_BART
bert_model = args.bert_model or DEFAULT_BERT
blip_model = args.blip_model or DEFAULT_BLIP
# Models init
# All models are loaded eagerly at import time. BLIP and BART are explicitly
# loaded as float32 and moved to the CPU.
# NOTE(review): `bart`/`bert` and their tokenizers are not used by any route in
# this file; they are kept because they are module-level names.
print('Initializing BLIP...')
blip_processor = AutoProcessor.from_pretrained(blip_model)
blip = BlipForConditionalGeneration.from_pretrained(
    blip_model,
    torch_dtype=torch.float32,
).to("cpu")

print('Initializing BART...')
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
bart = BartForConditionalGeneration.from_pretrained(
    bart_model,
    torch_dtype=torch.float32,
).to("cpu")

print('Initializing BERT...')
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model)
# No dtype/device override here: uses from_pretrained's defaults.
bert = AutoModelForSequenceClassification.from_pretrained(bert_model)
# Flask init
app = Flask(__name__)
# Cap request bodies at 100 MiB (presumably sized for base64-encoded images
# posted to /api/caption); Flask/Werkzeug answer 413 Request Entity Too Large
# when a larger body is read.
app.config.update(MAX_CONTENT_LENGTH=100 * 1024 * 1024)
# Model inference helpers
def caption_image(raw_image, max_new_tokens=20):
    """Return a text caption for ``raw_image`` generated by the BLIP model.

    Args:
        raw_image: a PIL image; it is converted to RGB before processing.
        max_new_tokens: maximum number of tokens BLIP may generate.

    Returns:
        The decoded caption string, with special tokens stripped.
    """
    rgb_image = raw_image.convert('RGB')
    model_inputs = blip_processor(rgb_image, return_tensors="pt")
    # Match the device/dtype the BLIP model was loaded with (CPU, float32).
    model_inputs = model_inputs.to("cpu", torch.float32)
    generated_ids = blip.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # A single image was processed, so only the first sequence is decoded.
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
@app.route('/', methods=['GET'])
def index():
    """Liveness check: reply to GET / with a fixed plain-text message."""
    status_message = 'I work OK'
    return status_message
@app.route('/api/caption', methods=['POST'])
def api_caption():
    """Caption a base64-encoded image posted as JSON.

    Expects a JSON object body of the form ``{"image": "<base64 data>"}`` and
    responds with ``{"caption": "<text>"}``.

    Malformed requests are rejected with HTTP 400 instead of surfacing as a
    500 Internal Server Error. This covers a body that is not a JSON object, a
    missing or non-string ``image``, undecodable base64, and bytes PIL cannot
    read as an image.
    """
    # silent=True: return None (instead of raising) for non-JSON bodies.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        abort(400, description="Request body must be a JSON object with a base64 'image' string")

    try:
        image = Image.open(BytesIO(base64.b64decode(data['image'])))
        # Image.open only reads the header; force a full decode here so a
        # truncated or corrupt payload is reported as a client error.
        image.load()
    except (ValueError, OSError, Image.DecompressionBombError):
        # binascii.Error (bad base64) subclasses ValueError;
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        abort(400, description="Could not decode 'image' as a base64-encoded image")

    caption = caption_image(image)
    return jsonify({'caption': caption})
# Start Flask's built-in (development) server on the resolved address; blocks
# until the process is interrupted.
app.run(port=port, host=host)