Add keyword extraction and image generation APIs

This commit is contained in:
SillyLossy
2023-03-01 17:58:41 +02:00
parent a0d803b958
commit a8dc30be0c
3 changed files with 247 additions and 52 deletions

View File

@@ -81,14 +81,56 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
> 2. Six fixed categories
> 3. Value range from 0.0 to 1.0
### Key phrase extraction
`POST /api/keywords`
#### **Input**
```
{ "text": "text to be scanned for key phrases" }
```
#### **Output**
```
{
"keywords": [
"array of",
"extracted",
"keywords",
]
}
```
### GPT-2 for Stable Diffusion prompt generation
`POST /api/prompt`
#### **Input**
```
{ "name": "character name (optional)", "text": "textual summary of a character" }
```
#### **Output**
```
{ "prompts": [ "array of generated prompts" ] }
```
### Stable Diffusion for image generation
`POST /api/image`
#### **Input**
```
{ "prompt": "prompt to be generated" }
```
#### **Output**
```
{ "image": "base64 encoded image" }
```
## Additional options
| Flag | Description |
| -------------- | -------------------------------------------------------------------- |
| `--port` | Specify the port on which the application is hosted. Default: *5100* |
| `--listen` | Hosts the app on the local network |
| `--share` | Shares the app on CloudFlare tunnel |
| `--cpu` | Run the models on the CPU instead of CUDA |
| `--bart-model` | Load a custom BART model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
| `--bert-model` | Load a custom BERT model.<br>Expects a HuggingFace model ID.<br>Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
| `--blip-model` | Load a custom BLIP model.<br>Expects a HuggingFace model Id.<br>Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
| Flag | Description |
| ------------------------ | ---------------------------------------------------------------------- |
| `--port` | Specify the port on which the application is hosted. Default: **5100** |
| `--listen` | Hosts the app on the local network |
| `--share` | Shares the app on CloudFlare tunnel |
| `--cpu` | Run the models on the CPU instead of CUDA |
| `--summarization-model` | Load a custom BART summarization model.<br>Expects a HuggingFace model ID.<br>Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
| `--classification-model` | Load a custom BERT sentiment classification model.<br>Expects a HuggingFace model ID.<br>Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
| `--captioning-model` | Load a custom BLIP captioning model.<br>Expects a HuggingFace model ID.<br>Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
| `--prompt-model` | Load a custom GPT-2 prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |

View File

@@ -2,5 +2,9 @@ flask
flask-cloudflared
markdown
Pillow
--extra-index-url https://download.pytorch.org/whl/cu116
torch >= 1.9, < 1.13
numpy
accelerate
git+https://github.com/huggingface/transformers
git+https://github.com/huggingface/diffusers

235
server.py
View File

@@ -3,18 +3,28 @@ import markdown
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
from transformers import AutoModelForTokenClassification, TokenClassificationPipeline
from transformers.pipelines import AggregationStrategy
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import unicodedata
import torch
import time
from PIL import Image
import base64
from io import BytesIO
import numpy as np
from diffusers import StableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler
# Constants
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'
DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'
DEFAULT_SUMMARIZATION_MODEL = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
DEFAULT_CLASSIFICATION_MODEL = 'bhadresh-savani/distilbert-base-uncased-emotion'
DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
'repetition_penalty': 1.0,
@@ -35,55 +45,115 @@ parser.add_argument('--share', action='store_true',
help="Shares the app on CloudFlare tunnel")
parser.add_argument('--cpu', action='store_true',
help="Runs the models on the CPU")
parser.add_argument('--bart-model', help="Load a custom BART model")
parser.add_argument('--bert-model', help="Load a custom BERT model")
parser.add_argument('--blip-model', help="Load a custom BLIP model")
parser.add_argument('--summarization-model',
help="Load a custom BART summarization model")
parser.add_argument('--classification-model',
help="Load a custom BERT text classification model")
parser.add_argument('--captioning-model',
help="Load a custom BLIP captioning model")
parser.add_argument('--keyphrase-model',
help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
help="Load a custom GPT-2 prompt generation model")
parser.add_argument('--sd-model',
help="Load a custom SD image generation model")
parser.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
args = parser.parse_args()
if args.port:
port = args.port
else:
port = 5100
if args.listen:
host = '0.0.0.0'
else:
host = 'localhost'
if args.bart_model:
bart_model = args.bart_model
else:
bart_model = DEFAULT_BART
if args.bert_model:
bert_model = args.bert_model
else:
bert_model = DEFAULT_BERT
if args.blip_model:
blip_model = args.blip_model
else:
blip_model = DEFAULT_BLIP
port = args.port if args.port else 5100
host = '0.0.0.0' if args.listen else 'localhost'
summarization_model = args.summarization_model if args.summarization_model else DEFAULT_SUMMARIZATION_MODEL
classification_model = args.classification_model if args.classification_model else DEFAULT_CLASSIFICATION_MODEL
captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
# Models init
device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
print('Initializing BLIP...')
blip_processor = AutoProcessor.from_pretrained(blip_model)
print('Initializing BLIP image captioning model...')
blip_processor = AutoProcessor.from_pretrained(captioning_model)
blip = BlipForConditionalGeneration.from_pretrained(
blip_model, torch_dtype=torch_dtype).to(device)
captioning_model, torch_dtype=torch_dtype).to(device)
print('Initializing BART...')
bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
print('Initializing BART text summarization model...')
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
bart = BartForConditionalGeneration.from_pretrained(
bart_model, torch_dtype=torch_dtype).to(device)
summarization_model, torch_dtype=torch_dtype).to(device)
print('Initializing BERT...')
bert_classifier = pipeline("text-classification", model=bert_model,
return_all_scores=True, device=device, torch_dtype=torch_dtype)
print('Initializing BERT sentiment classification model...')
bert_classifier = pipeline("text-classification", model=classification_model,
top_k=None, device=device, torch_dtype=torch_dtype)
print('Initializing keyword extractor...')
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs
)
def postprocess(self, model_outputs):
results = super().postprocess(
model_outputs=model_outputs,
aggregation_strategy=AggregationStrategy.SIMPLE
if self.model.config.model_type == "roberta"
else AggregationStrategy.FIRST,
)
return np.unique([result.get("word").strip() for result in results])
keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model)
print('Initializing GPT prompt generator')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
gpt_model = GPT2LMHeadModel.from_pretrained(
'FredZhang7/anime-anything-promptgen-v2')
prompt_generator = pipeline(
'text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
print('Initializing Stable Diffusion pipeline')
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
sd_device = torch.device(sd_device_string)
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model,
custom_pipeline="lpw_stable_diffusion",
torch_dtype=sd_torch_dtype,
).to(sd_device)
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
sd_pipe.enable_attention_slicing()
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
sd_pipe.scheduler.config)
prompt_prefix = "best quality, absurdres, "
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
error legs, bad legs, multiple legs, missing legs, error lighting,
error shadow, error reflection, text, error, extra digit, fewer digits,
cropped, worst quality, low quality, normal quality, jpeg artifacts,
signature, watermark, username, blurry"""
# list of key phrases to be looking for in text (unused for now)
indicator_list = ['female', 'girl', 'male', 'boy', 'woman', 'man', 'hair', 'eyes', 'skin', 'wears',
'appearance', 'costume', 'clothes', 'body', 'tall', 'short', 'chubby', 'thin',
'expression', 'angry', 'sad', 'blush', 'smile', 'happy', 'depressed', 'long',
'cold', 'breasts', 'chest', 'tail', 'ears', 'fur', 'race', 'species', 'wearing',
'shoes', 'boots', 'shirt', 'panties', 'bra', 'skirt', 'dress', 'kimono', 'wings', 'horns',
'pants', 'shorts', 'leggins', 'sandals', 'hat', 'glasses', 'sweater', 'hoodie', 'sweatshirt']
# Flask init
app = Flask(__name__)
@@ -126,11 +196,52 @@ def summarize(text: str, params: dict) -> str:
summary = bart_tokenizer.batch_decode(
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]
# Normalize string
summary = " ".join(unicodedata.normalize("NFKC", summary).strip().split())
summary = normalize_string(summary)
return summary
def normalize_string(input: str) -> str:
output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
return output
def extract_keywords(text: str) -> list[str]:
punctuation = '(){}[]\n\r<>'
trans = str.maketrans(punctuation, ' '*len(punctuation))
text = text.translate(trans)
text = normalize_string(text)
return list(keyphrase_pipe(text))
def generate_prompt(keywords: list[str], length: int = 100, num: int = 4) -> str:
prompt = ', '.join(keywords)
outs = prompt_generator(prompt, max_length=length, num_return_sequences=num, do_sample=True,
repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
return [out['generated_text'] for out in outs]
def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image:
prompt = normalize_string(f'{prompt_prefix}{input}')
print(prompt)
image = sd_pipe(
prompt=prompt,
negative_prompt=neg_prompt,
num_inference_steps=steps,
guidance_scale=scale,
).images[0]
image.save("./debug.png")
return image
def image_to_base64(image: Image):
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
@app.before_request
# Request time measuring
def before_request():
@@ -190,8 +301,46 @@ def api_classify():
return jsonify({'classification': classification})
@app.route('/api/keywords', methods=['POST'])
def api_keywords():
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
abort(400, '"text" is required')
keywords = extract_keywords(data['text'])
return jsonify({'keywords': keywords})
@app.route('/api/prompt', methods=['POST'])
def api_prompt():
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
abort(400, '"text" is required')
keywords = extract_keywords(data['text'])
if 'name' in data or isinstance(data['name'], str):
keywords.insert(0, data['name'])
prompts = generate_prompt(keywords)
return jsonify({'prompts': prompts})
@app.route('/api/image', methods=['POST'])
def api_image():
data = request.get_json()
if not 'prompt' in data or not isinstance(data['prompt'], str):
abort(400, '"prompt" is required')
image = generate_image(data['prompt'])
base64image = image_to_base64(image)
return jsonify({'image': base64image})
if args.share:
# Doesn't work currently
from flask_cloudflared import run_with_cloudflared
run_with_cloudflared(app)