mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-05 11:30:13 +00:00
Add keyword extraction and image generation APIs
This commit is contained in:
60
README.md
60
README.md
@@ -81,14 +81,56 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
|
||||
> 2. Six fixed categories
|
||||
> 3. Value range from 0.0 to 1.0
|
||||
|
||||
### Key phrase extraction
|
||||
`POST /api/keywords`
|
||||
#### **Input**
|
||||
```
|
||||
{ "text": "text to be scanned for key phrases" }
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{
|
||||
"keywords": [
|
||||
"array of",
|
||||
"extracted",
|
||||
"keywords",
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### GPT-2 for Stable Diffusion prompt generation
|
||||
`POST /api/prompt`
|
||||
#### **Input**
|
||||
```
|
||||
{ "name": "character name (optional)", "text": "textual summary of a character" }
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{ "prompts": [ "array of generated prompts" ] }
|
||||
```
|
||||
|
||||
### Stable Diffusion for image generation
|
||||
`POST /api/image`
|
||||
#### **Input**
|
||||
```
|
||||
{ "prompt": "prompt to be generated" }
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{ "image": "base64 encoded image" }
|
||||
```
|
||||
|
||||
## Additional options
|
||||
| Flag | Description |
|
||||
| -------------- | -------------------------------------------------------------------- |
|
||||
| `--port` | Specify the port on which the application is hosted. Default: *5100* |
|
||||
| `--listen` | Hosts the app on the local network |
|
||||
| `--share` | Shares the app on CloudFlare tunnel |
|
||||
| `--cpu` | Run the models on the CPU instead of CUDA |
|
||||
| `--bart-model` | Load a custom BART model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
|
||||
| `--bert-model` | Load a custom BERT model.<br>Expects a HuggingFace model ID.<br>Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
|
||||
| `--blip-model` | Load a custom BLIP model.<br>Expects a HuggingFace model Id.<br>Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
|
||||
| Flag | Description |
|
||||
| ------------------------ | ---------------------------------------------------------------------- |
|
||||
| `--port` | Specify the port on which the application is hosted. Default: **5100** |
|
||||
| `--listen` | Hosts the app on the local network |
|
||||
| `--share` | Shares the app on CloudFlare tunnel |
|
||||
| `--cpu` | Run the models on the CPU instead of CUDA |
|
||||
| `--summarization-model` | Load a custom BART summarization model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
|
||||
| `--classification-model` | Load a custom BERT sentiment classification model.<br>Expects a HuggingFace model ID.<br>Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
|
||||
| `--captioning-model` | Load a custom BLIP captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
|
||||
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
|
||||
| `--prompt-model` | Load a custom GPT-2 prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
|
||||
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
|
||||
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
|
||||
@@ -2,5 +2,9 @@ flask
|
||||
flask-cloudflared
|
||||
markdown
|
||||
Pillow
|
||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||
torch >= 1.9, < 1.13
|
||||
numpy
|
||||
accelerate
|
||||
git+https://github.com/huggingface/transformers
|
||||
git+https://github.com/huggingface/diffusers
|
||||
235
server.py
235
server.py
@@ -3,18 +3,28 @@ import markdown
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
|
||||
from transformers import AutoModelForTokenClassification, TokenClassificationPipeline
|
||||
from transformers.pipelines import AggregationStrategy
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
|
||||
# Constants
|
||||
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
|
||||
DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
|
||||
DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'
|
||||
DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_SUMMARIZATION_MODEL = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
|
||||
DEFAULT_CLASSIFICATION_MODEL = 'bhadresh-savani/distilbert-base-uncased-emotion'
|
||||
DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
DEFAULT_SUMMARIZE_PARAMS = {
|
||||
'temperature': 1.0,
|
||||
'repetition_penalty': 1.0,
|
||||
@@ -35,55 +45,115 @@ parser.add_argument('--share', action='store_true',
|
||||
help="Shares the app on CloudFlare tunnel")
|
||||
parser.add_argument('--cpu', action='store_true',
|
||||
help="Runs the models on the CPU")
|
||||
parser.add_argument('--bart-model', help="Load a custom BART model")
|
||||
parser.add_argument('--bert-model', help="Load a custom BERT model")
|
||||
parser.add_argument('--blip-model', help="Load a custom BLIP model")
|
||||
parser.add_argument('--summarization-model',
|
||||
help="Load a custom BART summarization model")
|
||||
parser.add_argument('--classification-model',
|
||||
help="Load a custom BERT text classification model")
|
||||
parser.add_argument('--captioning-model',
|
||||
help="Load a custom BLIP captioning model")
|
||||
parser.add_argument('--keyphrase-model',
|
||||
help="Load a custom keyphrase extraction model")
|
||||
parser.add_argument('--prompt-model',
|
||||
help="Load a custom GPT-2 prompt generation model")
|
||||
parser.add_argument('--sd-model',
|
||||
help="Load a custom SD image generation model")
|
||||
parser.add_argument('--sd-cpu',
|
||||
help="Force the SD pipeline to run on the CPU")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port:
|
||||
port = args.port
|
||||
else:
|
||||
port = 5100
|
||||
|
||||
if args.listen:
|
||||
host = '0.0.0.0'
|
||||
else:
|
||||
host = 'localhost'
|
||||
|
||||
if args.bart_model:
|
||||
bart_model = args.bart_model
|
||||
else:
|
||||
bart_model = DEFAULT_BART
|
||||
|
||||
if args.bert_model:
|
||||
bert_model = args.bert_model
|
||||
else:
|
||||
bert_model = DEFAULT_BERT
|
||||
|
||||
if args.blip_model:
|
||||
blip_model = args.blip_model
|
||||
else:
|
||||
blip_model = DEFAULT_BLIP
|
||||
port = args.port if args.port else 5100
|
||||
host = '0.0.0.0' if args.listen else 'localhost'
|
||||
summarization_model = args.summarization_model if args.summarization_model else DEFAULT_SUMMARIZATION_MODEL
|
||||
classification_model = args.classification_model if args.classification_model else DEFAULT_CLASSIFICATION_MODEL
|
||||
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
||||
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
||||
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
||||
|
||||
# Models init
|
||||
device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device = torch.device(device_string)
|
||||
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
||||
|
||||
print('Initializing BLIP...')
|
||||
blip_processor = AutoProcessor.from_pretrained(blip_model)
|
||||
print('Initializing BLIP image captioning model...')
|
||||
blip_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
blip = BlipForConditionalGeneration.from_pretrained(
|
||||
blip_model, torch_dtype=torch_dtype).to(device)
|
||||
captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
print('Initializing BART...')
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
|
||||
print('Initializing BART text summarization model...')
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
bart = BartForConditionalGeneration.from_pretrained(
|
||||
bart_model, torch_dtype=torch_dtype).to(device)
|
||||
summarization_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
print('Initializing BERT...')
|
||||
bert_classifier = pipeline("text-classification", model=bert_model,
|
||||
return_all_scores=True, device=device, torch_dtype=torch_dtype)
|
||||
print('Initializing BERT sentiment classification model...')
|
||||
bert_classifier = pipeline("text-classification", model=classification_model,
|
||||
top_k=None, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
print('Initializing keyword extractor...')
|
||||
|
||||
|
||||
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(
|
||||
model=AutoModelForTokenClassification.from_pretrained(model),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model),
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
results = super().postprocess(
|
||||
model_outputs=model_outputs,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE
|
||||
if self.model.config.model_type == "roberta"
|
||||
else AggregationStrategy.FIRST,
|
||||
)
|
||||
return np.unique([result.get("word").strip() for result in results])
|
||||
|
||||
|
||||
keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model)
|
||||
|
||||
print('Initializing GPT prompt generator')
|
||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
gpt_model = GPT2LMHeadModel.from_pretrained(
|
||||
'FredZhang7/anime-anything-promptgen-v2')
|
||||
prompt_generator = pipeline(
|
||||
'text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
|
||||
|
||||
|
||||
print('Initializing Stable Diffusion pipeline')
|
||||
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
||||
sd_device = torch.device(sd_device_string)
|
||||
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
sd_model,
|
||||
custom_pipeline="lpw_stable_diffusion",
|
||||
torch_dtype=sd_torch_dtype,
|
||||
).to(sd_device)
|
||||
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
|
||||
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
sd_pipe.scheduler.config)
|
||||
|
||||
prompt_prefix = "best quality, absurdres, "
|
||||
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
error hands, bad hands, error fingers, bad fingers, missing fingers
|
||||
error legs, bad legs, multiple legs, missing legs, error lighting,
|
||||
error shadow, error reflection, text, error, extra digit, fewer digits,
|
||||
cropped, worst quality, low quality, normal quality, jpeg artifacts,
|
||||
signature, watermark, username, blurry"""
|
||||
|
||||
|
||||
# list of key phrases to be looking for in text (unused for now)
|
||||
indicator_list = ['female', 'girl', 'male', 'boy', 'woman', 'man', 'hair', 'eyes', 'skin', 'wears',
|
||||
'appearance', 'costume', 'clothes', 'body', 'tall', 'short', 'chubby', 'thin',
|
||||
'expression', 'angry', 'sad', 'blush', 'smile', 'happy', 'depressed', 'long',
|
||||
'cold', 'breasts', 'chest', 'tail', 'ears', 'fur', 'race', 'species', 'wearing',
|
||||
'shoes', 'boots', 'shirt', 'panties', 'bra', 'skirt', 'dress', 'kimono', 'wings', 'horns',
|
||||
'pants', 'shorts', 'leggins', 'sandals', 'hat', 'glasses', 'sweater', 'hoodie', 'sweatshirt']
|
||||
|
||||
# Flask init
|
||||
app = Flask(__name__)
|
||||
@@ -126,11 +196,52 @@ def summarize(text: str, params: dict) -> str:
|
||||
summary = bart_tokenizer.batch_decode(
|
||||
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)[0]
|
||||
# Normalize string
|
||||
summary = " ".join(unicodedata.normalize("NFKC", summary).strip().split())
|
||||
summary = normalize_string(summary)
|
||||
return summary
|
||||
|
||||
|
||||
def normalize_string(input: str) -> str:
|
||||
output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
|
||||
return output
|
||||
|
||||
|
||||
def extract_keywords(text: str) -> list[str]:
|
||||
punctuation = '(){}[]\n\r<>'
|
||||
trans = str.maketrans(punctuation, ' '*len(punctuation))
|
||||
text = text.translate(trans)
|
||||
text = normalize_string(text)
|
||||
return list(keyphrase_pipe(text))
|
||||
|
||||
|
||||
def generate_prompt(keywords: list[str], length: int = 100, num: int = 4) -> str:
|
||||
prompt = ', '.join(keywords)
|
||||
outs = prompt_generator(prompt, max_length=length, num_return_sequences=num, do_sample=True,
|
||||
repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
|
||||
return [out['generated_text'] for out in outs]
|
||||
|
||||
|
||||
def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image:
|
||||
prompt = normalize_string(f'{prompt_prefix}{input}')
|
||||
print(prompt)
|
||||
|
||||
image = sd_pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=scale,
|
||||
).images[0]
|
||||
|
||||
image.save("./debug.png")
|
||||
return image
|
||||
|
||||
|
||||
def image_to_base64(image: Image):
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return img_str
|
||||
|
||||
|
||||
@app.before_request
|
||||
# Request time measuring
|
||||
def before_request():
|
||||
@@ -190,8 +301,46 @@ def api_classify():
|
||||
return jsonify({'classification': classification})
|
||||
|
||||
|
||||
@app.route('/api/keywords', methods=['POST'])
|
||||
def api_keywords():
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
keywords = extract_keywords(data['text'])
|
||||
return jsonify({'keywords': keywords})
|
||||
|
||||
|
||||
@app.route('/api/prompt', methods=['POST'])
|
||||
def api_prompt():
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
keywords = extract_keywords(data['text'])
|
||||
|
||||
if 'name' in data or isinstance(data['name'], str):
|
||||
keywords.insert(0, data['name'])
|
||||
|
||||
prompts = generate_prompt(keywords)
|
||||
return jsonify({'prompts': prompts})
|
||||
|
||||
|
||||
@app.route('/api/image', methods=['POST'])
|
||||
def api_image():
|
||||
data = request.get_json()
|
||||
|
||||
if not 'prompt' in data or not isinstance(data['prompt'], str):
|
||||
abort(400, '"prompt" is required')
|
||||
|
||||
image = generate_image(data['prompt'])
|
||||
base64image = image_to_base64(image)
|
||||
return jsonify({'image': base64image})
|
||||
|
||||
|
||||
if args.share:
|
||||
# Doesn't work currently
|
||||
from flask_cloudflared import run_with_cloudflared
|
||||
run_with_cloudflared(app)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user