commit 01ad5366ffbb90b674ede6efd6f4c2cfdc3f6111
Author: SillyLossy
Date:   Tue Feb 28 22:51:34 2023 +0200

    Initial commit

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e6afdea
--- /dev/null
+++ b/README.md
@@ -0,0 +1,11 @@
+# TavernAI - Extras
+## What is this
+A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/TavernAI) extensions
+## How to run
+* Install Python 3.10
+* Run `pip install -r requirements.txt`
+* Run `python server.py`
+## Included functionality
+* BART model for text summarization
+* BERT model for text classification
+* BLIP model for image captioning
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3c13ffb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+flask
+flask_cloudflared
+Pillow
+torch >= 1.9, < 1.13
+git+https://github.com/huggingface/transformers
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..5f168f7
--- /dev/null
+++ b/server.py
@@ -0,0 +1,80 @@
+"""Web API exposing transformers models (BART/BERT/BLIP) for TavernAI extras."""
+from flask import Flask, jsonify, request
+import argparse
+from transformers import AutoTokenizer, BartForConditionalGeneration
+from transformers import BlipForConditionalGeneration, AutoProcessor
+from transformers import AutoModelForSequenceClassification
+import torch
+from PIL import Image
+import base64
+from io import BytesIO
+
+# Constants
+# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
+DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
+DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'
+DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'
+UPLOAD_FOLDER = './uploads'
+
+# Script arguments (defaults handled by argparse instead of post-hoc if/else)
+parser = argparse.ArgumentParser(
+    prog='TavernAI Extras',
+    description='Web API for transformers models')
+parser.add_argument('--port', type=int, default=5100, help="Specify the port on which the application is hosted")
+parser.add_argument('--listen', action='store_true', help="Hosts the app on the local network")
+parser.add_argument('--share', action='store_true', help="Shares the app on cloudflare tunnel")
+parser.add_argument('--bart-model', default=DEFAULT_BART, help="Customize a BART model to be used by the app")
+parser.add_argument('--bert-model', default=DEFAULT_BERT, help="Customize a BERT model to be used by the app")
+parser.add_argument('--blip-model', default=DEFAULT_BLIP, help="Customize a BLIP model to be used by the app")
+
+args = parser.parse_args()
+
+port = args.port
+host = '0.0.0.0' if args.listen else 'localhost'
+bart_model = args.bart_model
+bert_model = args.bert_model
+blip_model = args.blip_model
+
+if args.share:
+    # Imported lazily: flask_cloudflared is only needed when tunnelling.
+    from flask_cloudflared import _run_cloudflared
+    cloudflare = _run_cloudflared(port)
+
+# Models init
+print('Initializing BLIP...')
+blip_processor = AutoProcessor.from_pretrained(blip_model)
+blip = BlipForConditionalGeneration.from_pretrained(blip_model, torch_dtype=torch.float32).to("cpu")
+
+print('Initializing BART...')
+bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
+bart = BartForConditionalGeneration.from_pretrained(bart_model, torch_dtype=torch.float32).to("cpu")
+
+print('Initializing BERT...')
+bert_tokenizer = AutoTokenizer.from_pretrained(bert_model)
+bert = AutoModelForSequenceClassification.from_pretrained(bert_model)
+
+# Flask init
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
+
+# AI stuff
+def caption_image(raw_image, max_new_tokens=20):
+    """Return a BLIP-generated caption (str) for a PIL image."""
+    inputs = blip_processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
+    outputs = blip.generate(**inputs, max_new_tokens=max_new_tokens)
+    caption = blip_processor.decode(outputs[0], skip_special_tokens=True)
+    return caption
+
+@app.route('/', methods=['GET'])
+def index():
+    # Health-check endpoint.
+    return 'I work OK'
+
+@app.route('/api/caption', methods=['POST'])
+def api_caption():
+    # Expects JSON body {"image": "<base64-encoded image bytes>"}.
+    data = request.get_json()
+    image = Image.open(BytesIO(base64.b64decode(data['image'])))
+    caption = caption_image(image)
+    return jsonify({ 'caption': caption })
+
+if __name__ == '__main__':
+    app.run(host=host, port=port)