mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-13 23:40:12 +00:00
Remove unused modules
This commit is contained in:
30
README.md
30
README.md
@@ -120,8 +120,6 @@ cd SillyTavern-extras
|
||||
| `caption` | Image captioning | ✔️ Yes |
|
||||
| `summarize` | Text summarization | ✔️ Yes |
|
||||
| `classify` | Text sentiment classification | ✔️ Yes |
|
||||
| `keywords` | Text key phrases extraction | ✔️ Yes |
|
||||
| `prompt` | SD prompt generation from text | ✔️ Yes |
|
||||
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
|
||||
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No |
|
||||
| `chromadb` | Infinity context server | :x: No |
|
||||
@@ -264,34 +262,6 @@ None
|
||||
> 2. List of categories defined by the summarization model
|
||||
> 3. Value range from 0.0 to 1.0
|
||||
|
||||
### Key phrase extraction
|
||||
`POST /api/keywords`
|
||||
#### **Input**
|
||||
```
|
||||
{ "text": "text to be scanned for key phrases" }
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{
|
||||
"keywords": [
|
||||
"array of",
|
||||
"extracted",
|
||||
"keywords",
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Stable Diffusion prompt generation
|
||||
`POST /api/prompt`
|
||||
#### **Input**
|
||||
```
|
||||
{ "name": "character name (optional)", "text": "textual summary of a character" }
|
||||
```
|
||||
#### **Output**
|
||||
```
|
||||
{ "prompts": [ "array of generated prompts" ] }
|
||||
```
|
||||
|
||||
### Stable Diffusion image generation
|
||||
`POST /api/image`
|
||||
#### **Input**
|
||||
|
||||
61
constants.py
61
constants.py
@@ -5,8 +5,6 @@ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
|
||||
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
|
||||
# Also try: 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
||||
DEFAULT_KEYPHRASE_MODEL = "ml6team/keyphrase-extraction-distilbert-inspec"
|
||||
DEFAULT_PROMPT_MODEL = "FredZhang7/anime-anything-promptgen-v2"
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
||||
@@ -49,62 +47,3 @@ error legs, bad legs, multiple legs, missing legs, error lighting,
|
||||
error shadow, error reflection, text, error, extra digit, fewer digits,
|
||||
cropped, worst quality, low quality, normal quality, jpeg artifacts,
|
||||
signature, watermark, username, blurry"""
|
||||
|
||||
|
||||
# Key phrases to look for in a character description (unused for now).
# Order matches the original list; the only change is the corrected
# spelling of "leggings" (was "leggins").
INDICATOR_LIST = [
    "female",
    "girl",
    "male",
    "boy",
    "woman",
    "man",
    "hair",
    "eyes",
    "skin",
    "wears",
    "appearance",
    "costume",
    "clothes",
    "body",
    "tall",
    "short",
    "chubby",
    "thin",
    "expression",
    "angry",
    "sad",
    "blush",
    "smile",
    "happy",
    "depressed",
    "long",
    "cold",
    "breasts",
    "chest",
    "tail",
    "ears",
    "fur",
    "race",
    "species",
    "wearing",
    "shoes",
    "boots",
    "shirt",
    "panties",
    "bra",
    "skirt",
    "dress",
    "kimono",
    "wings",
    "horns",
    "pants",
    "shorts",
    "leggings",
    "sandals",
    "hat",
    "glasses",
    "sweater",
    "hoodie",
    "sweatshirt",
]
|
||||
|
||||
26
pipelines.py
26
pipelines.py
@@ -1,26 +0,0 @@
|
||||
from transformers import (
|
||||
AutoModelForTokenClassification,
|
||||
AutoTokenizer,
|
||||
TokenClassificationPipeline,
|
||||
)
|
||||
from transformers.pipelines import AggregationStrategy
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that yields the distinct key phrases of a text."""

    def __init__(self, model, *args, **kwargs):
        """Build the pipeline from a single checkpoint identifier.

        ``model`` is passed to both ``AutoModelForTokenClassification`` and
        ``AutoTokenizer``; any further positional/keyword arguments are
        forwarded unchanged to ``TokenClassificationPipeline``.
        """
        classifier = AutoModelForTokenClassification.from_pretrained(model)
        tokenizer = AutoTokenizer.from_pretrained(model)
        super().__init__(model=classifier, tokenizer=tokenizer, *args, **kwargs)

    def postprocess(self, model_outputs):
        """Aggregate token predictions and return the unique, stripped phrases.

        RoBERTa models use the SIMPLE aggregation strategy; every other
        model type uses FIRST. The result is whatever ``np.unique`` returns
        for the list of stripped "word" values.
        """
        if self.model.config.model_type == "roberta":
            strategy = AggregationStrategy.SIMPLE
        else:
            strategy = AggregationStrategy.FIRST

        entities = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=strategy,
        )
        phrases = [entity.get("word").strip() for entity in entities]
        return np.unique(phrases)
|
||||
83
server.py
83
server.py
@@ -13,7 +13,7 @@ import markdown
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
||||
from transformers import BlipForConditionalGeneration
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
@@ -41,7 +41,7 @@ class SplitArgs(argparse.Action):
|
||||
|
||||
# Script arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="TavernAI Extras", description="Web API for transformers models"
|
||||
prog="SillyTavern Extras", description="Web API for transformers models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, help="Specify the port on which the application is hosted"
|
||||
@@ -58,10 +58,6 @@ parser.add_argument(
|
||||
"--classification-model", help="Load a custom text classification model"
|
||||
)
|
||||
parser.add_argument("--captioning-model", help="Load a custom captioning model")
|
||||
parser.add_argument(
|
||||
"--keyphrase-model", help="Load a custom keyphrase extraction model"
|
||||
)
|
||||
parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
|
||||
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
|
||||
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
||||
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
||||
@@ -119,10 +115,6 @@ classification_model = (
|
||||
captioning_model = (
|
||||
args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
||||
)
|
||||
keyphrase_model = (
|
||||
args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
||||
)
|
||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
||||
embedding_model = (
|
||||
args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
||||
)
|
||||
@@ -178,21 +170,6 @@ if "classify" in modules:
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
if "keywords" in modules:
|
||||
print("Initializing a keyword extraction pipeline...")
|
||||
import pipelines as pipelines
|
||||
|
||||
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
|
||||
|
||||
if "prompt" in modules:
|
||||
print("Initializing a prompt generator")
|
||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||
gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
|
||||
prompt_generator = pipeline(
|
||||
"text-generation", model=gpt_model, tokenizer=gpt_tokenizer
|
||||
)
|
||||
|
||||
if "sd" in modules and not sd_use_remote:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
@@ -362,29 +339,6 @@ def normalize_string(input: str) -> str:
|
||||
return output
|
||||
|
||||
|
||||
def extract_keywords(text: str) -> list:
    """Return the key phrases the keyphrase pipeline finds in ``text``.

    Brackets, braces, angle brackets and line breaks are replaced by spaces
    first, then the text goes through ``normalize_string`` before it is
    handed to ``keyphrase_pipe``.
    """
    noise_chars = "(){}[]\n\r<>"
    # Map every noise character to a single space (same table maketrans built).
    to_space = str.maketrans(dict.fromkeys(noise_chars, " "))
    cleaned = normalize_string(text.translate(to_space))
    return list(keyphrase_pipe(cleaned))
|
||||
|
||||
|
||||
def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> list:
    """Generate Stable Diffusion prompts seeded with the given keywords.

    Args:
        keywords: Phrases joined with ", " to form the seed text.
        length: Passed to the generator as ``max_length``.
        num: Passed to the generator as ``num_return_sequences``.

    Returns:
        A list with the ``"generated_text"`` of every sequence returned by
        ``prompt_generator`` (the original annotation claimed ``str``, but a
        list has always been returned).
    """
    prompt = ", ".join(keywords)
    outs = prompt_generator(
        prompt,
        max_length=length,
        num_return_sequences=num,
        do_sample=True,
        repetition_penalty=1.2,
        temperature=0.7,
        top_k=4,
        early_stopping=True,
    )
    return [out["generated_text"] for out in outs]
|
||||
|
||||
|
||||
def generate_image(data: dict) -> Image:
|
||||
prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
|
||||
|
||||
@@ -552,39 +506,6 @@ def api_classify_labels():
|
||||
return jsonify({"labels": labels})
|
||||
|
||||
|
||||
@app.route("/api/keywords", methods=["POST"])
@require_module("keywords")
def api_keywords():
    """Extract key phrases from the ``"text"`` field of the JSON request body.

    Responds with ``{"keywords": [...]}``. Aborts with HTTP 400 when the body
    is not a JSON object or ``"text"`` is missing or not a string.
    """
    data = request.get_json()

    # A JSON body that is not an object (null, list, string, number) used to
    # raise TypeError on the key lookup and surface as a 500; reject it with
    # the same 400 as a missing field instead.
    if not isinstance(data, dict) or not isinstance(data.get("text"), str):
        abort(400, '"text" is required')

    print("Keywords input:", data["text"], sep="\n")
    keywords = extract_keywords(data["text"])
    print("Keywords output:", keywords, sep="\n")
    return jsonify({"keywords": keywords})
|
||||
|
||||
|
||||
@app.route("/api/prompt", methods=["POST"])
@require_module("prompt")
def api_prompt():
    """Generate Stable Diffusion prompts from a textual character summary.

    Reads ``"text"`` (required string) and ``"name"`` (optional string) from
    the JSON request body. Key phrases are extracted from the text, the name
    (when given) is placed first, and the generated prompts are returned as
    ``{"prompts": [...]}``. Aborts with HTTP 400 when the body is not a JSON
    object or ``"text"`` is missing or not a string.
    """
    data = request.get_json()

    # A JSON body that is not an object (null, list, string, number) used to
    # raise TypeError on the key lookup and surface as a 500; reject it with
    # the same 400 as a missing field instead.
    if not isinstance(data, dict) or not isinstance(data.get("text"), str):
        abort(400, '"text" is required')

    keywords = extract_keywords(data["text"])

    # The character name, when supplied, leads the seed keywords.
    if isinstance(data.get("name"), str):
        keywords.insert(0, data["name"])

    print("Prompt input:", data["text"], sep="\n")
    prompts = generate_prompt(keywords)
    print("Prompt output:", prompts, sep="\n")
    return jsonify({"prompts": prompts})
|
||||
|
||||
|
||||
@app.route("/api/image", methods=["POST"])
|
||||
@require_module("sd")
|
||||
def api_image():
|
||||
|
||||
Reference in New Issue
Block a user