Initial commit

This commit is contained in:
SillyLossy
2023-02-28 22:51:34 +02:00
commit 01ad5366ff
3 changed files with 114 additions and 0 deletions

11
README.md Normal file
View File

@@ -0,0 +1,11 @@
# TavernAI - Extras
## What is this
A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/TavernAI) extensions
## How to run
* Install Python 3.10
* Run `pip install -r requirements.txt`
* Run `python server.py`
## Included functionality
* BART model for text summarization
* BERT model for text classification
* BLIP model for image captioning

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
flask
flask_cloudflared
Pillow
torch >= 1.9, < 1.13
git+https://github.com/huggingface/transformers

98
server.py Normal file
View File

@@ -0,0 +1,98 @@
from flask import Flask, jsonify, request, flash, abort
import argparse
from transformers import AutoTokenizer, BartForConditionalGeneration
from transformers import BlipForConditionalGeneration, AutoProcessor
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import unicodedata
import torch
import json
from PIL import Image
import base64
from io import BytesIO
# Constants
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
# Default Hugging Face model IDs; each can be overridden with the matching
# --bart-model / --bert-model / --blip-model CLI flag parsed below.
DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'  # text summarization
DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'  # text classification
DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'  # image captioning
# Stored into app.config['UPLOAD_FOLDER'] during Flask init.
UPLOAD_FOLDER = './uploads'
# Script arguments
# Defaults are declared directly on the parser, so every `args.<name>` is
# always populated and the original if/else ladder is unnecessary. This also
# fixes the falsy-value bug where an explicit `--port 0` was silently
# replaced by 5100.
parser = argparse.ArgumentParser(
    prog='TavernAI Extras',
    description='Web API for transformers models')
parser.add_argument('--port', type=int, default=5100,
                    help="Specify the port on which the application is hosted")
parser.add_argument('--listen', action='store_true',
                    help="Hosts the app on the local network")
parser.add_argument('--share', action='store_true',
                    help="Shares the app on cloudflare tunnel")
parser.add_argument('--bart-model', default=DEFAULT_BART,
                    help="Customize a BART model to be used by the app")
parser.add_argument('--bert-model', default=DEFAULT_BERT,
                    help="Customize a BERT model to be used by the app")
parser.add_argument('--blip-model', default=DEFAULT_BLIP,
                    help="Customize a BLIP model to be used by the app")
args = parser.parse_args()

port = args.port
host = '0.0.0.0' if args.listen else 'localhost'
bart_model = args.bart_model
bert_model = args.bert_model
blip_model = args.blip_model

if args.share:
    # Imported lazily: flask_cloudflared is only required when tunnelling.
    from flask_cloudflared import _run_cloudflared
    cloudflare = _run_cloudflared(port)
# Models init
# All three models are loaded eagerly at startup so the first HTTP request
# does not pay the download/initialization cost.
print('Initializing BLIP...')
blip_processor = AutoProcessor.from_pretrained(blip_model)
# float32 on CPU keeps the server runnable without a GPU.
blip = BlipForConditionalGeneration.from_pretrained(blip_model, torch_dtype=torch.float32).to("cpu")
print('Initializing BART...')
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
bart = BartForConditionalGeneration.from_pretrained(bart_model, torch_dtype=torch.float32).to("cpu")
print('Initializing BERT...')
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model)
# NOTE(review): bart/bert models are initialized here but no route in this
# view uses them yet — presumably consumed by endpoints added later; confirm.
bert = AutoModelForSequenceClassification.from_pretrained(bert_model)
# Flask init
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Cap request bodies at 100 MB; base64-encoded images can be large.
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
# AI stuff
def caption_image(raw_image, max_new_tokens=20):
    """Generate a short text caption for *raw_image* with the BLIP model.

    raw_image -- a PIL image; converted to RGB before being fed to BLIP
    max_new_tokens -- upper bound on the number of generated caption tokens
    """
    rgb_image = raw_image.convert('RGB')
    model_inputs = blip_processor(rgb_image, return_tensors="pt").to("cpu", torch.float32)
    generated_ids = blip.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
@app.route('/', methods=['GET'])
def index():
    """Health-check endpoint; returns a fixed plain-text status string."""
    status_message = 'I work OK'
    return status_message
@app.route('/api/caption', methods=['POST'])
def api_caption():
    """Caption an image POSTed as JSON: {"image": "<base64>"} -> {"caption": str}.

    Responds with HTTP 400 when the body is not JSON, the 'image' key is
    missing, or the payload cannot be decoded as a base64 image — instead of
    crashing with an unhandled exception (HTTP 500) as before.
    """
    data = request.get_json()

    # Reject malformed requests up front.
    if not data or 'image' not in data:
        abort(400, "JSON body with an 'image' field is required")

    try:
        image = Image.open(BytesIO(base64.b64decode(data['image'])))
    except Exception:
        # Covers invalid base64 as well as bytes PIL cannot parse as an image.
        abort(400, "'image' is not a valid base64-encoded image")

    caption = caption_image(image)
    return jsonify({ 'caption': caption })
app.run(host=host, port=port)