diff --git a/README.md b/README.md index 14a4d63..daf24b4 100644 --- a/README.md +++ b/README.md @@ -81,14 +81,56 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave > 2. Six fixed categories > 3. Value range from 0.0 to 1.0 +### Key phrase extraction +`POST /api/keywords` +#### **Input** +``` +{ "text": "text to be scanned for key phrases" } +``` +#### **Output** +``` +{ + "keywords": [ + "array of", + "extracted", + "keywords", + ] +} +``` + +### GPT-2 for Stable Diffusion prompt generation +`POST /api/prompt` +#### **Input** +``` +{ "name": "character name (optional)", "text": "textual summary of a character" } +``` +#### **Output** +``` +{ "prompts": [ "array of generated prompts" ] } +``` + +### Stable Diffusion for image generation +`POST /api/image` +#### **Input** +``` +{ "prompt": "prompt to be generated" } +``` +#### **Output** +``` +{ "image": "base64 encoded image" } +``` ## Additional options -| Flag | Description | -| -------------- | -------------------------------------------------------------------- | -| `--port` | Specify the port on which the application is hosted. Default: *5100* | -| `--listen` | Hosts the app on the local network | -| `--share` | Shares the app on CloudFlare tunnel | -| `--cpu` | Run the models on the CPU instead of CUDA | -| `--bart-model` | Load a custom BART model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) | -| `--bert-model` | Load a custom BERT model.
Expects a HuggingFace model ID.
Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) | -| `--blip-model` | Load a custom BLIP model.
Expects a HuggingFace model Id.
Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) | +| Flag | Description | +| ------------------------ | ---------------------------------------------------------------------- | +| `--port` | Specify the port on which the application is hosted. Default: **5100** | +| `--listen` | Hosts the app on the local network | +| `--share` | Shares the app on CloudFlare tunnel | +| `--cpu` | Run the models on the CPU instead of CUDA | +| `--summarization-model` | Load a custom BART summarization model.
Expects a HuggingFace model ID.
Default: [Qiliang/bart-large-cnn-samsum-ChatGPT_v3](https://huggingface.co/Qiliang/bart-large-cnn-samsum-ChatGPT_v3) | +| `--classification-model` | Load a custom BERT sentiment classification model.
Expects a HuggingFace model ID.
Default: [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion) | +| `--captioning-model` | Load a custom BLIP captioning model.
Expects a HuggingFace model ID.
Default: [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) | +| `--keyphrase-model` | Load a custom key phrase extraction model.
Expects a HuggingFace model ID.
Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) | +| `--prompt-model` | Load a custom GPT-2 prompt generation model.
Expects a HuggingFace model ID.
Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) | +| `--sd-model` | Load a custom Stable Diffusion image generation model.
Expects a HuggingFace model ID.
Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)
*Must have VAE pre-baked in PyTorch format or the output will look drab!* | +| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.
**SLOW!** | \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 11ff04f..c34a064 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,9 @@ flask flask-cloudflared markdown Pillow +--extra-index-url https://download.pytorch.org/whl/cu116 torch >= 1.9, < 1.13 +numpy +accelerate git+https://github.com/huggingface/transformers +git+https://github.com/huggingface/diffusers \ No newline at end of file diff --git a/server.py b/server.py index db24dfd..1e1b81c 100644 --- a/server.py +++ b/server.py @@ -3,18 +3,28 @@ import markdown import argparse from transformers import AutoTokenizer, AutoProcessor, pipeline from transformers import BlipForConditionalGeneration, BartForConditionalGeneration +from transformers import AutoModelForTokenClassification, TokenClassificationPipeline +from transformers.pipelines import AggregationStrategy +from transformers import GPT2Tokenizer, GPT2LMHeadModel import unicodedata import torch import time from PIL import Image import base64 from io import BytesIO +import numpy as np +from diffusers import StableDiffusionPipeline +from diffusers import EulerAncestralDiscreteScheduler + # Constants # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10' -DEFAULT_BART = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3' -DEFAULT_BERT = 'bhadresh-savani/distilbert-base-uncased-emotion' -DEFAULT_BLIP = 'Salesforce/blip-image-captioning-base' +DEFAULT_SUMMARIZATION_MODEL = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3' +DEFAULT_CLASSIFICATION_MODEL = 'bhadresh-savani/distilbert-base-uncased-emotion' +DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base' +DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec' +DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2' +DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped" DEFAULT_SUMMARIZE_PARAMS = { 'temperature': 1.0, 'repetition_penalty': 1.0, @@ -35,55 +45,115 @@ parser.add_argument('--share', action='store_true', help="Shares the 
app on CloudFlare tunnel") parser.add_argument('--cpu', action='store_true', help="Runs the models on the CPU") -parser.add_argument('--bart-model', help="Load a custom BART model") -parser.add_argument('--bert-model', help="Load a custom BERT model") -parser.add_argument('--blip-model', help="Load a custom BLIP model") +parser.add_argument('--summarization-model', + help="Load a custom BART summarization model") +parser.add_argument('--classification-model', + help="Load a custom BERT text classification model") +parser.add_argument('--captioning-model', + help="Load a custom BLIP captioning model") +parser.add_argument('--keyphrase-model', + help="Load a custom keyphrase extraction model") +parser.add_argument('--prompt-model', + help="Load a custom GPT-2 prompt generation model") +parser.add_argument('--sd-model', + help="Load a custom SD image generation model") +parser.add_argument('--sd-cpu', + help="Force the SD pipeline to run on the CPU") args = parser.parse_args() -if args.port: - port = args.port -else: - port = 5100 - -if args.listen: - host = '0.0.0.0' -else: - host = 'localhost' - -if args.bart_model: - bart_model = args.bart_model -else: - bart_model = DEFAULT_BART - -if args.bert_model: - bert_model = args.bert_model -else: - bert_model = DEFAULT_BERT - -if args.blip_model: - blip_model = args.blip_model -else: - blip_model = DEFAULT_BLIP +port = args.port if args.port else 5100 +host = '0.0.0.0' if args.listen else 'localhost' +summarization_model = args.summarization_model if args.summarization_model else DEFAULT_SUMMARIZATION_MODEL +classification_model = args.classification_model if args.classification_model else DEFAULT_CLASSIFICATION_MODEL +captioning_model = args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL +keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL +prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL +sd_model = args.sd_model if 
args.sd_model else DEFAULT_SD_MODEL # Models init -device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu" +device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu" device = torch.device(device_string) torch_dtype = torch.float32 if device_string == "cpu" else torch.float16 -print('Initializing BLIP...') -blip_processor = AutoProcessor.from_pretrained(blip_model) +print('Initializing BLIP image captioning model...') +blip_processor = AutoProcessor.from_pretrained(captioning_model) blip = BlipForConditionalGeneration.from_pretrained( - blip_model, torch_dtype=torch_dtype).to(device) + captioning_model, torch_dtype=torch_dtype).to(device) -print('Initializing BART...') -bart_tokenizer = AutoTokenizer.from_pretrained(bart_model) +print('Initializing BART text summarization model...') +bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model) bart = BartForConditionalGeneration.from_pretrained( - bart_model, torch_dtype=torch_dtype).to(device) + summarization_model, torch_dtype=torch_dtype).to(device) -print('Initializing BERT...') -bert_classifier = pipeline("text-classification", model=bert_model, - return_all_scores=True, device=device, torch_dtype=torch_dtype) +print('Initializing BERT sentiment classification model...') +bert_classifier = pipeline("text-classification", model=classification_model, + top_k=None, device=device, torch_dtype=torch_dtype) + +print('Initializing keyword extractor...') + + +class KeyphraseExtractionPipeline(TokenClassificationPipeline): + def __init__(self, model, *args, **kwargs): + super().__init__( + model=AutoModelForTokenClassification.from_pretrained(model), + tokenizer=AutoTokenizer.from_pretrained(model), + *args, + **kwargs + ) + + def postprocess(self, model_outputs): + results = super().postprocess( + model_outputs=model_outputs, + aggregation_strategy=AggregationStrategy.SIMPLE + if self.model.config.model_type == "roberta" + else AggregationStrategy.FIRST, + ) + 
return np.unique([result.get("word").strip() for result in results]) + + +keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model) + +print('Initializing GPT prompt generator') +gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2') +gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) +gpt_model = GPT2LMHeadModel.from_pretrained( + 'FredZhang7/anime-anything-promptgen-v2') +prompt_generator = pipeline( + 'text-generation', model=gpt_model, tokenizer=gpt_tokenizer) + + +print('Initializing Stable Diffusion pipeline') +sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu" +sd_device = torch.device(sd_device_string) +sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16 +sd_pipe = StableDiffusionPipeline.from_pretrained( + sd_model, + custom_pipeline="lpw_stable_diffusion", + torch_dtype=sd_torch_dtype, +).to(sd_device) +sd_pipe.safety_checker = lambda images, clip_input: (images, False) +sd_pipe.enable_attention_slicing() +# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config) +sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( + sd_pipe.scheduler.config) + +prompt_prefix = "best quality, absurdres, " +neg_prompt = """lowres, bad anatomy, error body, error hair, error arm, +error hands, bad hands, error fingers, bad fingers, missing fingers +error legs, bad legs, multiple legs, missing legs, error lighting, +error shadow, error reflection, text, error, extra digit, fewer digits, +cropped, worst quality, low quality, normal quality, jpeg artifacts, +signature, watermark, username, blurry""" + + +# list of key phrases to be looking for in text (unused for now) +indicator_list = ['female', 'girl', 'male', 'boy', 'woman', 'man', 'hair', 'eyes', 'skin', 'wears', + 'appearance', 'costume', 'clothes', 'body', 'tall', 'short', 'chubby', 'thin', + 'expression', 'angry', 'sad', 'blush', 'smile', 'happy', 'depressed', 'long', + 'cold', 'breasts', 'chest', 'tail', 'ears', 
'fur', 'race', 'species', 'wearing', + 'shoes', 'boots', 'shirt', 'panties', 'bra', 'skirt', 'dress', 'kimono', 'wings', 'horns', + 'pants', 'shorts', 'leggins', 'sandals', 'hat', 'glasses', 'sweater', 'hoodie', 'sweatshirt'] # Flask init app = Flask(__name__) @@ -126,11 +196,52 @@ def summarize(text: str, params: dict) -> str: summary = bart_tokenizer.batch_decode( summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True )[0] - # Normalize string - summary = " ".join(unicodedata.normalize("NFKC", summary).strip().split()) + summary = normalize_string(summary) return summary +def normalize_string(input: str) -> str: + output = " ".join(unicodedata.normalize("NFKC", input).strip().split()) + return output + + +def extract_keywords(text: str) -> list[str]: + punctuation = '(){}[]\n\r<>' + trans = str.maketrans(punctuation, ' '*len(punctuation)) + text = text.translate(trans) + text = normalize_string(text) + return list(keyphrase_pipe(text)) + + +def generate_prompt(keywords: list[str], length: int = 100, num: int = 4) -> str: + prompt = ', '.join(keywords) + outs = prompt_generator(prompt, max_length=length, num_return_sequences=num, do_sample=True, + repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True) + return [out['generated_text'] for out in outs] + + +def generate_image(input: str, steps: int = 30, scale: int = 6) -> Image: + prompt = normalize_string(f'{prompt_prefix}{input}') + print(prompt) + + image = sd_pipe( + prompt=prompt, + negative_prompt=neg_prompt, + num_inference_steps=steps, + guidance_scale=scale, + ).images[0] + + image.save("./debug.png") + return image + + +def image_to_base64(image: Image): + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return img_str + + @app.before_request # Request time measuring def before_request(): @@ -190,8 +301,46 @@ def api_classify(): return jsonify({'classification': classification}) 
+@app.route('/api/keywords', methods=['POST']) +def api_keywords(): + data = request.get_json() + + if not 'text' in data or not isinstance(data['text'], str): + abort(400, '"text" is required') + + keywords = extract_keywords(data['text']) + return jsonify({'keywords': keywords}) + + +@app.route('/api/prompt', methods=['POST']) +def api_prompt(): + data = request.get_json() + + if not 'text' in data or not isinstance(data['text'], str): + abort(400, '"text" is required') + + keywords = extract_keywords(data['text']) + + if 'name' in data and isinstance(data['name'], str): + keywords.insert(0, data['name']) + + prompts = generate_prompt(keywords) + return jsonify({'prompts': prompts}) + + +@app.route('/api/image', methods=['POST']) +def api_image(): + data = request.get_json() + + if not 'prompt' in data or not isinstance(data['prompt'], str): + abort(400, '"prompt" is required') + + image = generate_image(data['prompt']) + base64image = image_to_base64(image) + return jsonify({'image': base64image}) + + if args.share: - # Doesn't work currently from flask_cloudflared import run_with_cloudflared run_with_cloudflared(app)