Remove unused modules

This commit is contained in:
SillyLossy
2023-06-04 01:58:49 +03:00
parent befdc746c5
commit a0000347e9
4 changed files with 2 additions and 198 deletions

View File

@@ -120,8 +120,6 @@ cd SillyTavern-extras
| `caption` | Image captioning | ✔️ Yes |
| `summarize` | Text summarization | ✔️ Yes |
| `classify` | Text sentiment classification | ✔️ Yes |
| `keywords` | Text key phrases extraction | ✔️ Yes |
| `prompt` | SD prompt generation from text | ✔️ Yes |
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No |
| `chromadb` | Infinity context server | :x: No |
@@ -264,34 +262,6 @@ None
> 2. List of categories defined by the summarization model
> 3. Value range from 0.0 to 1.0
### Key phrase extraction
`POST /api/keywords`
#### **Input**
```
{ "text": "text to be scanned for key phrases" }
```
#### **Output**
```
{
  "keywords": [
    "array of",
    "extracted",
    "keywords"
  ]
}
```
### Stable Diffusion prompt generation
`POST /api/prompt`
#### **Input**
```
{ "name": "character name (optional)", "text": "textual summary of a character" }
```
#### **Output**
```
{ "prompts": [ "array of generated prompts" ] }
```
### Stable Diffusion image generation
`POST /api/image`
#### **Input**

View File

@@ -5,8 +5,6 @@ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
# Also try: 'Salesforce/blip-image-captioning-base'
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
DEFAULT_KEYPHRASE_MODEL = "ml6team/keyphrase-extraction-distilbert-inspec"
DEFAULT_PROMPT_MODEL = "FredZhang7/anime-anything-promptgen-v2"
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
@@ -49,62 +47,3 @@ error legs, bad legs, multiple legs, missing legs, error lighting,
error shadow, error reflection, text, error, extra digit, fewer digits,
cropped, worst quality, low quality, normal quality, jpeg artifacts,
signature, watermark, username, blurry"""
# Key phrases to look for in character-description text (currently unused).
# Grouped roughly by topic for readability; matching is plain substring-style,
# so keep entries lowercase single words.
INDICATOR_LIST = [
    # gender / person
    "female",
    "girl",
    "male",
    "boy",
    "woman",
    "man",
    # physical features
    "hair",
    "eyes",
    "skin",
    "wears",
    "appearance",
    "costume",
    "clothes",
    "body",
    "tall",
    "short",
    "chubby",
    "thin",
    # expression / mood
    "expression",
    "angry",
    "sad",
    "blush",
    "smile",
    "happy",
    "depressed",
    "long",
    "cold",
    # anatomy (incl. non-human traits)
    "breasts",
    "chest",
    "tail",
    "ears",
    "fur",
    "race",
    "species",
    # clothing items
    "wearing",
    "shoes",
    "boots",
    "shirt",
    "panties",
    "bra",
    "skirt",
    "dress",
    "kimono",
    "wings",
    "horns",
    "pants",
    "shorts",
    "leggins",
    "sandals",
    "hat",
    "glasses",
    "sweater",
    "hoodie",
    "sweatshirt",
]

View File

@@ -1,26 +0,0 @@
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
TokenClassificationPipeline,
)
from transformers.pipelines import AggregationStrategy
import numpy as np
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline specialized for keyphrase extraction.

    Accepts a single model name/path and loads both the model and its
    tokenizer from it; post-processing collapses token spans into a
    deduplicated array of keyphrase strings.
    """

    def __init__(self, model, *args, **kwargs):
        # `model` is a HF hub name or local path; load weights and tokenizer
        # from the same identifier so callers only pass one string.
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )

    def postprocess(self, model_outputs):
        # RoBERTa-style models get SIMPLE aggregation; everything else FIRST.
        results = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE
            if self.model.config.model_type == "roberta"
            else AggregationStrategy.FIRST,
        )
        # Strip whitespace and deduplicate; note np.unique returns a
        # sorted numpy array of strings, not a plain list.
        return np.unique([result.get("word").strip() for result in results])

View File

@@ -13,7 +13,7 @@ import markdown
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
from transformers import BlipForConditionalGeneration
import unicodedata
import torch
import time
@@ -41,7 +41,7 @@ class SplitArgs(argparse.Action):
# Script arguments
parser = argparse.ArgumentParser(
prog="TavernAI Extras", description="Web API for transformers models"
prog="SillyTavern Extras", description="Web API for transformers models"
)
parser.add_argument(
"--port", type=int, help="Specify the port on which the application is hosted"
@@ -58,10 +58,6 @@ parser.add_argument(
"--classification-model", help="Load a custom text classification model"
)
parser.add_argument("--captioning-model", help="Load a custom captioning model")
parser.add_argument(
"--keyphrase-model", help="Load a custom keyphrase extraction model"
)
parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
@@ -119,10 +115,6 @@ classification_model = (
captioning_model = (
args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
)
keyphrase_model = (
args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
)
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
embedding_model = (
args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
)
@@ -178,21 +170,6 @@ if "classify" in modules:
torch_dtype=torch_dtype,
)
if "keywords" in modules:
print("Initializing a keyword extraction pipeline...")
import pipelines as pipelines
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
if "prompt" in modules:
print("Initializing a prompt generator")
gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
prompt_generator = pipeline(
"text-generation", model=gpt_model, tokenizer=gpt_tokenizer
)
if "sd" in modules and not sd_use_remote:
from diffusers import StableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler
@@ -362,29 +339,6 @@ def normalize_string(input: str) -> str:
return output
def extract_keywords(text: str) -> list:
    """Run the keyphrase-extraction pipeline over *text*.

    Bracket and newline characters are flattened to spaces and the text is
    normalized before extraction; returns the pipeline's results as a list.
    """
    noise_chars = "(){}[]\n\r<>"
    # Map every noise character to a single space in one translate pass.
    table = str.maketrans(dict.fromkeys(noise_chars, " "))
    cleaned = normalize_string(text.translate(table))
    return list(keyphrase_pipe(cleaned))
def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> list:
    """Generate Stable Diffusion prompts from a list of keywords.

    Args:
        keywords: Key phrases joined (comma-separated) into the seed prompt.
        length: Maximum length of each generated sequence.
        num: Number of prompt variants to return.

    Returns:
        A list of ``num`` generated prompt strings.

    Note: the original annotation said ``-> str`` but the function has always
    returned a list of strings; the annotation is corrected here.
    """
    prompt = ", ".join(keywords)
    # Sampling settings favor varied but coherent SD-style prompts.
    outs = prompt_generator(
        prompt,
        max_length=length,
        num_return_sequences=num,
        do_sample=True,
        repetition_penalty=1.2,
        temperature=0.7,
        top_k=4,
        early_stopping=True,
    )
    return [out["generated_text"] for out in outs]
def generate_image(data: dict) -> Image:
prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
@@ -552,39 +506,6 @@ def api_classify_labels():
return jsonify({"labels": labels})
@app.route("/api/keywords", methods=["POST"])
@require_module("keywords")
def api_keywords():
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
abort(400, '"text" is required')
print("Keywords input:", data["text"], sep="\n")
keywords = extract_keywords(data["text"])
print("Keywords output:", keywords, sep="\n")
return jsonify({"keywords": keywords})
@app.route("/api/prompt", methods=["POST"])
@require_module("prompt")
def api_prompt():
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
abort(400, '"text" is required')
keywords = extract_keywords(data["text"])
if "name" in data and isinstance(data["name"], str):
keywords.insert(0, data["name"])
print("Prompt input:", data["text"], sep="\n")
prompts = generate_prompt(keywords)
print("Prompt output:", prompts, sep="\n")
return jsonify({"prompts": prompts})
@app.route("/api/image", methods=["POST"])
@require_module("sd")
def api_image():