mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 19:31:20 +00:00
Initial commit
This commit is contained in:
11
README.md
Normal file
11
README.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# TavernAI - Extras
|
||||||
|
## What is this
|
||||||
|
A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/TavernAI) extensions
|
||||||
|
## How to run
|
||||||
|
* Install Python 3.10
|
||||||
|
* Run `pip install -r requirements.txt`
|
||||||
|
* Run `python server.py`
|
||||||
|
## Included functionality
|
||||||
|
* BART model for text summarization
|
||||||
|
* BERT model for text classification
|
||||||
|
* BLIP model for image captioning
|
||||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
flask
|
||||||
|
flask_cloudflared
|
||||||
|
Pillow
|
||||||
|
torch >= 1.9, < 1.13
|
||||||
|
git+https://github.com/huggingface/transformers
|
||||||
98
server.py
Normal file
98
server.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from flask import Flask, jsonify, request, flash, abort
|
||||||
|
import argparse
|
||||||
|
from transformers import AutoTokenizer, BartForConditionalGeneration
|
||||||
|
from transformers import BlipForConditionalGeneration, AutoProcessor
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
import unicodedata
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
from PIL import Image
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
|
||||||
|
DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
|
||||||
|
DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'
|
||||||
|
DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'
|
||||||
|
UPLOAD_FOLDER = './uploads'
|
||||||
|
|
||||||
|
# Script arguments
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog = 'TavernAI Extras',
|
||||||
|
description = 'Web API for transformers models')
|
||||||
|
parser.add_argument('--port', type=int, help="Specify the port on which the application is hosted")
|
||||||
|
parser.add_argument('--listen', action='store_true', help="Hosts the app on the local network")
|
||||||
|
parser.add_argument('--share', action='store_true', help="Shares the app on cloudflare tunnel")
|
||||||
|
parser.add_argument('--bart-model', help="Customize a BART model to be used by the app")
|
||||||
|
parser.add_argument('--bert-model', help="Customize a BERT model to be used by the app")
|
||||||
|
parser.add_argument('--blip-model', help="Customize a BLIP model to be used by the app")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.port:
|
||||||
|
port = args.port
|
||||||
|
else:
|
||||||
|
port = 5100
|
||||||
|
|
||||||
|
if args.share:
|
||||||
|
from flask_cloudflared import _run_cloudflared
|
||||||
|
cloudflare = _run_cloudflared(port)
|
||||||
|
|
||||||
|
if args.listen:
|
||||||
|
host = '0.0.0.0'
|
||||||
|
else:
|
||||||
|
host = 'localhost'
|
||||||
|
|
||||||
|
if args.bart_model:
|
||||||
|
bart_model = args.bart_model
|
||||||
|
else:
|
||||||
|
bart_model = DEFAULT_BART
|
||||||
|
|
||||||
|
if args.bert_model:
|
||||||
|
bert_model = args.bert_model
|
||||||
|
else:
|
||||||
|
bert_model = DEFAULT_BERT
|
||||||
|
|
||||||
|
if args.blip_model:
|
||||||
|
blip_model = args.blip_model
|
||||||
|
else:
|
||||||
|
blip_model = DEFAULT_BLIP
|
||||||
|
|
||||||
|
# Models init
|
||||||
|
print('Initializing BLIP...')
|
||||||
|
blip_processor = AutoProcessor.from_pretrained(blip_model)
|
||||||
|
blip = BlipForConditionalGeneration.from_pretrained(blip_model, torch_dtype=torch.float32).to("cpu")
|
||||||
|
|
||||||
|
print('Initializing BART...')
|
||||||
|
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
|
||||||
|
bart = BartForConditionalGeneration.from_pretrained(bart_model, torch_dtype=torch.float32).to("cpu")
|
||||||
|
|
||||||
|
print('Initializing BERT...')
|
||||||
|
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
||||||
|
bert = AutoModelForSequenceClassification.from_pretrained(bert_model)
|
||||||
|
|
||||||
|
# Flask init
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024
|
||||||
|
|
||||||
|
# AI stuff
|
||||||
|
def caption_image(raw_image, max_new_tokens=20):
|
||||||
|
inputs = blip_processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||||
|
outputs = blip.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
caption = blip_processor.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
return caption
|
||||||
|
|
||||||
|
@app.route('/', methods=['GET'])
|
||||||
|
def index():
|
||||||
|
return 'I work OK'
|
||||||
|
|
||||||
|
@app.route('/api/caption', methods=['POST'])
|
||||||
|
def api_caption():
|
||||||
|
data = request.get_json()
|
||||||
|
image = Image.open(BytesIO(base64.b64decode(data['image'])))
|
||||||
|
caption = caption_image(image)
|
||||||
|
return jsonify({ 'caption': caption })
|
||||||
|
|
||||||
|
app.run(host=host, port=port)
|
||||||
Reference in New Issue
Block a user