diff --git a/README.md b/README.md
index 14a4d63..daf24b4 100644
--- a/README.md
+++ b/README.md
@@ -81,14 +81,56 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
> 2. Six fixed categories
> 3. Value range from 0.0 to 1.0
+### Key phrase extraction
+`POST /api/keywords`
+#### **Input**
+```
+{ "text": "text to be scanned for key phrases" }
+```
+#### **Output**
+```
+{
+ "keywords": [
+ "array of",
+ "extracted",
+ "keywords"
+ ]
+}
+```
+
+### GPT-2 for Stable Diffusion prompt generation
+`POST /api/prompt`
+#### **Input**
+```
+{ "name": "character name (optional)", "text": "textual summary of a character" }
+```
+#### **Output**
+```
+{ "prompts": [ "array of generated prompts" ] }
+```
+
+### Stable Diffusion for image generation
+`POST /api/image`
+#### **Input**
+```
+{ "prompt": "prompt to generate the image from" }
+```
+#### **Output**
+```
+{ "image": "base64 encoded image" }
+```
## Additional options
-| Flag | Description |
-| -------------- | -------------------------------------------------------------------- |
-| `--port` | Specify the port on which the application is hosted. Default: *5100* |
-| `--listen` | Hosts the app on the local network |
-| `--share` | Shares the app on CloudFlare tunnel |
-| `--cpu` | Run the models on the CPU instead of CUDA |
-| `--bart-model` | Load a custom BART model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
-| `--bert-model` | Load a custom BERT model.
Expects a HuggingFace model ID.
Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
-| `--blip-model` | Load a custom BLIP model.
Expects a HuggingFace model Id.
Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
+| Flag | Description |
+| ------------------------ | ---------------------------------------------------------------------- |
+| `--port` | Specify the port on which the application is hosted. Default: **5100** |
+| `--listen` | Hosts the app on the local network |
+| `--share` | Shares the app on CloudFlare tunnel |
+| `--cpu` | Run the models on the CPU instead of CUDA |
+| `--summarization-model` | Load a custom BART summarization model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) |
+| `--classification-model` | Load a custom BERT sentiment classification model.
Expects a HuggingFace model ID.
Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) |
+| `--captioning-model` | Load a custom BLIP captioning model.
Expects a HuggingFace model ID.
Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) |
+| `--keyphrase-model` | Load a custom key phrase extraction model.
Expects a HuggingFace model ID.
Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
+| `--prompt-model` | Load a custom GPT-2 prompt generation model.
Expects a HuggingFace model ID.
Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
+| `--sd-model` | Load a custom Stable Diffusion image generation model.
Expects a HuggingFace model ID.
Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)
*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
+| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.
**SLOW!** |
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 11ff04f..c34a064 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,9 @@ flask
flask-cloudflared
markdown
Pillow
+--extra-index-url https://download.pytorch.org/whl/cu116
torch >= 1.9, < 1.13
+numpy
+accelerate
git+https://github.com/huggingface/transformers
+git+https://github.com/huggingface/diffusers
\ No newline at end of file
diff --git a/server.py b/server.py
index db24dfd..1e1b81c 100644
--- a/server.py
+++ b/server.py
@@ -3,18 +3,28 @@ import markdown
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
+from transformers import AutoModelForTokenClassification, TokenClassificationPipeline
+from transformers.pipelines import AggregationStrategy
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
import unicodedata
import torch
import time
from PIL import Image
import base64
from io import BytesIO
+import numpy as np
+from diffusers import StableDiffusionPipeline
+from diffusers import EulerAncestralDiscreteScheduler
+
# Constants
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
-DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
-DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion'
-DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base'
+DEFAULT_SUMMARIZATION_MODEL = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
+DEFAULT_CLASSIFICATION_MODEL = 'bhadresh-savani/distilbert-base-uncased-emotion'
+DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
+DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
+DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
+DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
'repetition_penalty': 1.0,
@@ -35,55 +45,115 @@ parser.add_argument('--share', action='store_true',
help="Shares the app on CloudFlare tunnel")
parser.add_argument('--cpu', action='store_true',
help="Runs the models on the CPU")
-parser.add_argument('--bart-model', help="Load a custom BART model")
-parser.add_argument('--bert-model', help="Load a custom BERT model")
-parser.add_argument('--blip-model', help="Load a custom BLIP model")
parser.add_argument('--summarization-model',
                    help="Load a custom BART summarization model")
parser.add_argument('--classification-model',
                    help="Load a custom BERT text classification model")
parser.add_argument('--captioning-model',
                    help="Load a custom BLIP captioning model")
parser.add_argument('--keyphrase-model',
                    help="Load a custom keyphrase extraction model")
parser.add_argument('--prompt-model',
                    help="Load a custom GPT-2 prompt generation model")
parser.add_argument('--sd-model',
                    help="Load a custom SD image generation model")
# BUG FIX: --sd-cpu is documented as a bare flag; without action='store_true'
# argparse would require a value after it (and args.sd_cpu would be None by
# default, which happens to be falsy, but `--sd-cpu` alone would error out).
parser.add_argument('--sd-cpu', action='store_true',
                    help="Force the SD pipeline to run on the CPU")
args = parser.parse_args()
-if args.port:
- port = args.port
-else:
- port = 5100
-
-if args.listen:
- host = '0.0.0.0'
-else:
- host = 'localhost'
-
-if args.bart_model:
- bart_model = args.bart_model
-else:
- bart_model = DEFAULT_BART
-
-if args.bert_model:
- bert_model = args.bert_model
-else:
- bert_model = DEFAULT_BERT
-
-if args.blip_model:
- blip_model = args.blip_model
-else:
- blip_model = DEFAULT_BLIP
# Resolve runtime configuration: each CLI flag falls back to its default.
port = args.port or 5100
host = '0.0.0.0' if args.listen else 'localhost'
summarization_model = args.summarization_model or DEFAULT_SUMMARIZATION_MODEL
classification_model = args.classification_model or DEFAULT_CLASSIFICATION_MODEL
captioning_model = args.captioning_model or DEFAULT_CAPTIONING_MODEL
keyphrase_model = args.keyphrase_model or DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model or DEFAULT_PROMPT_MODEL
sd_model = args.sd_model or DEFAULT_SD_MODEL
# Models init
-device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
+device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
-print('Initializing BLIP...')
-blip_processor = AutoProcessor.from_pretrained(blip_model)
+print('Initializing BLIP image captioning model...')
+blip_processor = AutoProcessor.from_pretrained(captioning_model)
blip = BlipForConditionalGeneration.from_pretrained(
- blip_model, torch_dtype=torch_dtype).to(device)
+ captioning_model, torch_dtype=torch_dtype).to(device)
-print('Initializing BART...')
-bart_tokenizer = AutoTokenizer.from_pretrained(bart_model)
+print('Initializing BART text summarization model...')
+bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
bart = BartForConditionalGeneration.from_pretrained(
- bart_model, torch_dtype=torch_dtype).to(device)
+ summarization_model, torch_dtype=torch_dtype).to(device)
-print('Initializing BERT...')
-bert_classifier = pipeline("text-classification", model=bert_model,
- return_all_scores=True, device=device, torch_dtype=torch_dtype)
+print('Initializing BERT sentiment classification model...')
+bert_classifier = pipeline("text-classification", model=classification_model,
+ top_k=None, device=device, torch_dtype=torch_dtype)
+
+print('Initializing keyword extractor...')
+
+
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that yields unique, stripped key phrases."""

    def __init__(self, model, *args, **kwargs):
        # Resolve both weights and tokenizer from a single HuggingFace model ID.
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )

    def postprocess(self, model_outputs):
        # RoBERTa-style tokenizers aggregate cleanly with SIMPLE; other
        # model types need FIRST to merge subword pieces into phrases.
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST
        entities = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=strategy,
        )
        # De-duplicate (np.unique also sorts the phrases).
        return np.unique([entity.get("word").strip() for entity in entities])
+
+
+keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model)
+
print('Initializing GPT prompt generator')
# NOTE(review): the tokenizer stays pinned to distilgpt2 (the promptgen models
# are distilgpt2-based) — confirm compatibility when a custom --prompt-model
# is supplied.
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# BUG FIX: the model ID was hard-coded to the default, silently ignoring the
# --prompt-model flag; load the configured model instead.
gpt_model = GPT2LMHeadModel.from_pretrained(prompt_model)
prompt_generator = pipeline(
    'text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
+
+
print('Initializing Stable Diffusion pipeline')
# SD gets its own device selection so it can be forced to the CPU (--sd-cpu)
# independently of the other models.
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
sd_device = torch.device(sd_device_string)
# fp16 only makes sense on CUDA; fall back to fp32 on CPU.
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
sd_pipe = StableDiffusionPipeline.from_pretrained(
    sd_model,
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=sd_torch_dtype,
).to(sd_device)
# Disable the NSFW checker. BUG FIX: diffusers expects the second element to
# be a per-image list of flags, not a scalar False.
sd_pipe.safety_checker = lambda images, clip_input: (images, [False] * len(images))
# Reduce VRAM usage at a small speed cost.
sd_pipe.enable_attention_slicing()
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
    sd_pipe.scheduler.config)
+
# Quality tags prepended to every user prompt before SD generation.
prompt_prefix = "best quality, absurdres, "
# Negative prompt applied to every generation to suppress common artifacts.
# NOTE(review): a comma looks missing after "missing fingers" — confirm before
# changing, as this string is sent to the model verbatim.
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers
error legs, bad legs, multiple legs, missing legs, error lighting,
error shadow, error reflection, text, error, extra digit, fewer digits,
cropped, worst quality, low quality, normal quality, jpeg artifacts,
signature, watermark, username, blurry"""


# list of key phrases to be looking for in text (unused for now)
indicator_list = ['female', 'girl', 'male', 'boy', 'woman', 'man', 'hair', 'eyes', 'skin', 'wears',
                  'appearance', 'costume', 'clothes', 'body', 'tall', 'short', 'chubby', 'thin',
                  'expression', 'angry', 'sad', 'blush', 'smile', 'happy', 'depressed', 'long',
                  'cold', 'breasts', 'chest', 'tail', 'ears', 'fur', 'race', 'species', 'wearing',
                  'shoes', 'boots', 'shirt', 'panties', 'bra', 'skirt', 'dress', 'kimono', 'wings', 'horns',
                  'pants', 'shorts', 'leggins', 'sandals', 'hat', 'glasses', 'sweater', 'hoodie', 'sweatshirt']
# Flask init
app = Flask(__name__)
@@ -126,11 +196,52 @@ def summarize(text: str, params: dict) -> str:
summary = bart_tokenizer.batch_decode(
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]
- # Normalize string
- summary = " ".join(unicodedata.normalize("NFKC", summary).strip().split())
+ summary = normalize_string(summary)
return summary
def normalize_string(input: str) -> str:
    """Apply NFKC Unicode normalization and collapse all runs of whitespace
    into single spaces, trimming the ends."""
    # NOTE(review): `input` shadows the builtin; kept to preserve the signature.
    normalized = unicodedata.normalize("NFKC", input)
    # split() with no argument already ignores leading/trailing whitespace.
    return " ".join(normalized.split())
+
+
def extract_keywords(text: str) -> list[str]:
    """Run the keyphrase pipeline over *text* and return the unique phrases."""
    # Brackets and line breaks confuse the token classifier; blank them out
    # in a single C-level pass before normalizing.
    noise = '(){}[]\n\r<>'
    cleaned = text.translate(str.maketrans(noise, ' ' * len(noise)))
    return list(keyphrase_pipe(normalize_string(cleaned)))
+
+
def generate_prompt(keywords: list[str], length: int = 100, num: int = 4) -> list[str]:
    """Generate *num* Stable Diffusion prompts seeded by the given keywords.

    The keywords are comma-joined into a seed string for the GPT-2 prompt
    generator. BUG FIX: the return annotation said ``str`` but the function
    returns a list of generated prompt strings.
    """
    prompt = ', '.join(keywords)
    outs = prompt_generator(prompt, max_length=length, num_return_sequences=num, do_sample=True,
                            repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
    return [out['generated_text'] for out in outs]
+
+
def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image.Image:
    """Render one image for *input* with the Stable Diffusion pipeline.

    The quality prefix is prepended and the shared negative prompt applied.
    BUG FIX: the return annotation named the PIL ``Image`` module; the value
    is an ``Image.Image`` instance.
    """
    prompt = normalize_string(f'{prompt_prefix}{input}')
    print(prompt)  # debug: log the final prompt sent to the pipeline

    image = sd_pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        num_inference_steps=steps,
        guidance_scale=scale,
    ).images[0]

    # NOTE(review): debug artifact overwritten on every request — consider
    # removing once the pipeline output is trusted.
    image.save("./debug.png")
    return image
+
+
def image_to_base64(image: Image):
    """Serialize a PIL image to a base64-encoded JPEG string."""
    with BytesIO() as buffer:
        image.save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
+
+
@app.before_request
# Request time measuring
def before_request():
@@ -190,8 +301,46 @@ def api_classify():
return jsonify({'classification': classification})
@app.route('/api/keywords', methods=['POST'])
def api_keywords():
    """Extract key phrases from posted text.

    Expects JSON ``{"text": str}``; returns ``{"keywords": [str, ...]}``.
    """
    data = request.get_json()

    # Idiom fix: `'x' not in d` instead of `not 'x' in d`.
    if 'text' not in data or not isinstance(data['text'], str):
        abort(400, '"text" is required')

    keywords = extract_keywords(data['text'])
    return jsonify({'keywords': keywords})
+
+
@app.route('/api/prompt', methods=['POST'])
def api_prompt():
    """Generate SD prompts from a character summary.

    Expects JSON ``{"text": str, "name": str (optional)}``; returns
    ``{"prompts": [str, ...]}``.
    """
    data = request.get_json()

    if 'text' not in data or not isinstance(data['text'], str):
        abort(400, '"text" is required')

    keywords = extract_keywords(data['text'])

    # BUG FIX: was `'name' in data or isinstance(data['name'], str)`, which
    # raised KeyError whenever 'name' was absent; both conditions must hold.
    if 'name' in data and isinstance(data['name'], str):
        keywords.insert(0, data['name'])

    prompts = generate_prompt(keywords)
    return jsonify({'prompts': prompts})
+
+
@app.route('/api/image', methods=['POST'])
def api_image():
    """Generate an image from a prompt.

    Expects JSON ``{"prompt": str}``; returns ``{"image": base64 JPEG}``.
    """
    data = request.get_json()

    # Idiom fix: `'x' not in d` instead of `not 'x' in d`.
    if 'prompt' not in data or not isinstance(data['prompt'], str):
        abort(400, '"prompt" is required')

    image = generate_image(data['prompt'])
    base64image = image_to_base64(image)
    return jsonify({'image': base64image})
+
+
if args.share:
- # Doesn't work currently
from flask_cloudflared import run_with_cloudflared
run_with_cloudflared(app)