From a0000347e9ef0c2eab19e386be6cd8e7e293f2de Mon Sep 17 00:00:00 2001 From: SillyLossy Date: Sun, 4 Jun 2023 01:58:49 +0300 Subject: [PATCH] Remove unused modules --- README.md | 30 ------------------- constants.py | 61 -------------------------------------- pipelines.py | 26 ---------------- server.py | 83 ++-------------------------------------------------- 4 files changed, 2 insertions(+), 198 deletions(-) delete mode 100644 pipelines.py diff --git a/README.md b/README.md index cdf1791..fe69201 100644 --- a/README.md +++ b/README.md @@ -120,8 +120,6 @@ cd SillyTavern-extras | `caption` | Image captioning | ✔️ Yes | | `summarize` | Text summarization | ✔️ Yes | | `classify` | Text sentiment classification | ✔️ Yes | -| `keywords` | Text key phrases extraction | ✔️ Yes | -| `prompt` | SD prompt generation from text | ✔️ Yes | | `sd` | Stable Diffusion image generation | :x: No (✔️ remote) | | `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No | | `chromadb` | Infinity context server | :x: No | @@ -264,34 +262,6 @@ None > 2. List of categories defined by the summarization model > 3. Value range from 0.0 to 1.0 -### Key phrase extraction -`POST /api/keywords` -#### **Input** -``` -{ "text": "text to be scanned for key phrases" } -``` -#### **Output** -``` -{ - "keywords": [ - "array of", - "extracted", - "keywords", - ] -} -``` - -### Stable Diffusion prompt generation -`POST /api/prompt` -#### **Input** -``` -{ "name": "character name (optional)", "text": "textual summary of a character" } -``` -#### **Output** -``` -{ "prompts": [ "array of generated prompts" ] } -``` - ### Stable Diffusion image generation `POST /api/image` #### **Input** diff --git a/constants.py b/constants.py index 154321f..66c62a6 100644 --- a/constants.py +++ b/constants.py @@ -5,8 +5,6 @@ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3" DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion" # Also try: 'Salesforce/blip-image-captioning-base' DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large" -DEFAULT_KEYPHRASE_MODEL = "ml6team/keyphrase-extraction-distilbert-inspec" -DEFAULT_PROMPT_MODEL = "FredZhang7/anime-anything-promptgen-v2" DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped" DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2" DEFAULT_REMOTE_SD_HOST = "127.0.0.1" @@ -49,62 +47,3 @@ error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry""" - - -# list of key phrases to be looking for in text (unused for now) -INDICATOR_LIST = [ - "female", - "girl", - "male", - "boy", - "woman", - "man", - "hair", - "eyes", - "skin", - "wears", - "appearance", - "costume", - "clothes", - "body", - "tall", - "short", - "chubby", - "thin", - "expression", - "angry", - "sad", - "blush", - "smile", - "happy", - "depressed", - "long", - "cold", - "breasts", - "chest", - "tail", - "ears", - "fur", - "race", - "species", - "wearing", - "shoes", - "boots", - "shirt", - "panties", - "bra", - "skirt", - "dress", - "kimono", - "wings", - "horns", - "pants", - "shorts", - "leggins", - "sandals", - "hat", - "glasses", - "sweater", - "hoodie", - "sweatshirt", -] diff --git a/pipelines.py b/pipelines.py deleted file mode 100644 index 92733dc..0000000 --- a/pipelines.py +++ /dev/null @@ -1,26 +0,0 @@ -from transformers import ( - AutoModelForTokenClassification, - AutoTokenizer, - TokenClassificationPipeline, -) -from transformers.pipelines import AggregationStrategy -import numpy as np - - -class KeyphraseExtractionPipeline(TokenClassificationPipeline): - def __init__(self, model, *args, **kwargs): - super().__init__( - model=AutoModelForTokenClassification.from_pretrained(model), - tokenizer=AutoTokenizer.from_pretrained(model), - *args, - **kwargs - ) - - def postprocess(self, model_outputs): - results = super().postprocess( - model_outputs=model_outputs, - aggregation_strategy=AggregationStrategy.SIMPLE - if self.model.config.model_type == "roberta" - else AggregationStrategy.FIRST, - ) - return np.unique([result.get("word").strip() for result in results]) diff --git a/server.py b/server.py index 15633e3..36085c4 100644 --- a/server.py +++ b/server.py @@ -13,7 +13,7 @@ import markdown import argparse from transformers import AutoTokenizer, AutoProcessor, pipeline from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM -from transformers import BlipForConditionalGeneration, GPT2Tokenizer +from transformers import BlipForConditionalGeneration import unicodedata import torch import time @@ -41,7 +41,7 @@ class SplitArgs(argparse.Action): # Script arguments parser = argparse.ArgumentParser( - prog="TavernAI Extras", description="Web API for transformers models" + prog="SillyTavern Extras", description="Web API for transformers models" ) parser.add_argument( "--port", type=int, help="Specify the port on which the application is hosted" @@ -58,10 +58,6 @@ parser.add_argument( "--classification-model", help="Load a custom text classification model" ) parser.add_argument("--captioning-model", help="Load a custom captioning model") -parser.add_argument( - "--keyphrase-model", help="Load a custom keyphrase extraction model" -) -parser.add_argument("--prompt-model", help="Load a custom prompt generation model") parser.add_argument("--embedding-model", help="Load a custom text embedding model") parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance") parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)") @@ -119,10 +115,6 @@ classification_model = ( captioning_model = ( args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL ) -keyphrase_model = ( - args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL -) -prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL embedding_model = ( args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL ) @@ -178,21 +170,6 @@ if "classify" in modules: torch_dtype=torch_dtype, ) -if "keywords" in modules: - print("Initializing a keyword extraction pipeline...") - import pipelines as pipelines - - keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model) - -if "prompt" in modules: - print("Initializing a prompt generator") - gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2") - gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model) - prompt_generator = pipeline( - "text-generation", model=gpt_model, tokenizer=gpt_tokenizer - ) - if "sd" in modules and not sd_use_remote: from diffusers import StableDiffusionPipeline from diffusers import EulerAncestralDiscreteScheduler @@ -362,29 +339,6 @@ def normalize_string(input: str) -> str: return output -def extract_keywords(text: str) -> list: - punctuation = "(){}[]\n\r<>" - trans = str.maketrans(punctuation, " " * len(punctuation)) - text = text.translate(trans) - text = normalize_string(text) - return list(keyphrase_pipe(text)) - - -def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> str: - prompt = ", ".join(keywords) - outs = prompt_generator( - prompt, - max_length=length, - num_return_sequences=num, - do_sample=True, - repetition_penalty=1.2, - temperature=0.7, - top_k=4, - early_stopping=True, - ) - return [out["generated_text"] for out in outs] - - def generate_image(data: dict) -> Image: prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}') @@ -552,39 +506,6 @@ def api_classify_labels(): return jsonify({"labels": labels}) -@app.route("/api/keywords", methods=["POST"]) -@require_module("keywords") -def api_keywords(): - data = request.get_json() - - if "text" not in data or not isinstance(data["text"], str): - abort(400, '"text" is required') - - print("Keywords input:", data["text"], sep="\n") - keywords = extract_keywords(data["text"]) - print("Keywords output:", keywords, sep="\n") - return jsonify({"keywords": keywords}) - - -@app.route("/api/prompt", methods=["POST"]) -@require_module("prompt") -def api_prompt(): - data = request.get_json() - - if "text" not in data or not isinstance(data["text"], str): - abort(400, '"text" is required') - - keywords = extract_keywords(data["text"]) - - if "name" in data and isinstance(data["name"], str): - keywords.insert(0, data["name"]) - - print("Prompt input:", data["text"], sep="\n") - prompts = generate_prompt(keywords) - print("Prompt output:", prompts, sep="\n") - return jsonify({"prompts": prompts}) - - @app.route("/api/image", methods=["POST"]) @require_module("sd") def api_image():