mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-30 19:31:20 +00:00
Remove unused modules
This commit is contained in:
30
README.md
30
README.md
@@ -120,8 +120,6 @@ cd SillyTavern-extras
|
|||||||
| `caption` | Image captioning | ✔️ Yes |
|
| `caption` | Image captioning | ✔️ Yes |
|
||||||
| `summarize` | Text summarization | ✔️ Yes |
|
| `summarize` | Text summarization | ✔️ Yes |
|
||||||
| `classify` | Text sentiment classification | ✔️ Yes |
|
| `classify` | Text sentiment classification | ✔️ Yes |
|
||||||
| `keywords` | Text key phrases extraction | ✔️ Yes |
|
|
||||||
| `prompt` | SD prompt generation from text | ✔️ Yes |
|
|
||||||
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
|
| `sd` | Stable Diffusion image generation | :x: No (✔️ remote) |
|
||||||
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No |
|
| `tts` | [Silero TTS server](https://github.com/ouoertheo/silero-api-server) | :x: No |
|
||||||
| `chromadb` | Infinity context server | :x: No |
|
| `chromadb` | Infinity context server | :x: No |
|
||||||
@@ -264,34 +262,6 @@ None
|
|||||||
> 2. List of categories defined by the summarization model
|
> 2. List of categories defined by the summarization model
|
||||||
> 3. Value range from 0.0 to 1.0
|
> 3. Value range from 0.0 to 1.0
|
||||||
|
|
||||||
### Key phrase extraction
|
|
||||||
`POST /api/keywords`
|
|
||||||
#### **Input**
|
|
||||||
```
|
|
||||||
{ "text": "text to be scanned for key phrases" }
|
|
||||||
```
|
|
||||||
#### **Output**
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"keywords": [
|
|
||||||
"array of",
|
|
||||||
"extracted",
|
|
||||||
"keywords",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Stable Diffusion prompt generation
|
|
||||||
`POST /api/prompt`
|
|
||||||
#### **Input**
|
|
||||||
```
|
|
||||||
{ "name": "character name (optional)", "text": "textual summary of a character" }
|
|
||||||
```
|
|
||||||
#### **Output**
|
|
||||||
```
|
|
||||||
{ "prompts": [ "array of generated prompts" ] }
|
|
||||||
```
|
|
||||||
|
|
||||||
### Stable Diffusion image generation
|
### Stable Diffusion image generation
|
||||||
`POST /api/image`
|
`POST /api/image`
|
||||||
#### **Input**
|
#### **Input**
|
||||||
|
|||||||
61
constants.py
61
constants.py
@@ -5,8 +5,6 @@ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
|
|||||||
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
|
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
|
||||||
# Also try: 'Salesforce/blip-image-captioning-base'
|
# Also try: 'Salesforce/blip-image-captioning-base'
|
||||||
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
||||||
DEFAULT_KEYPHRASE_MODEL = "ml6team/keyphrase-extraction-distilbert-inspec"
|
|
||||||
DEFAULT_PROMPT_MODEL = "FredZhang7/anime-anything-promptgen-v2"
|
|
||||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||||
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
||||||
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
||||||
@@ -49,62 +47,3 @@ error legs, bad legs, multiple legs, missing legs, error lighting,
|
|||||||
error shadow, error reflection, text, error, extra digit, fewer digits,
|
error shadow, error reflection, text, error, extra digit, fewer digits,
|
||||||
cropped, worst quality, low quality, normal quality, jpeg artifacts,
|
cropped, worst quality, low quality, normal quality, jpeg artifacts,
|
||||||
signature, watermark, username, blurry"""
|
signature, watermark, username, blurry"""
|
||||||
|
|
||||||
|
|
||||||
# list of key phrases to be looking for in text (unused for now)
|
|
||||||
INDICATOR_LIST = [
|
|
||||||
"female",
|
|
||||||
"girl",
|
|
||||||
"male",
|
|
||||||
"boy",
|
|
||||||
"woman",
|
|
||||||
"man",
|
|
||||||
"hair",
|
|
||||||
"eyes",
|
|
||||||
"skin",
|
|
||||||
"wears",
|
|
||||||
"appearance",
|
|
||||||
"costume",
|
|
||||||
"clothes",
|
|
||||||
"body",
|
|
||||||
"tall",
|
|
||||||
"short",
|
|
||||||
"chubby",
|
|
||||||
"thin",
|
|
||||||
"expression",
|
|
||||||
"angry",
|
|
||||||
"sad",
|
|
||||||
"blush",
|
|
||||||
"smile",
|
|
||||||
"happy",
|
|
||||||
"depressed",
|
|
||||||
"long",
|
|
||||||
"cold",
|
|
||||||
"breasts",
|
|
||||||
"chest",
|
|
||||||
"tail",
|
|
||||||
"ears",
|
|
||||||
"fur",
|
|
||||||
"race",
|
|
||||||
"species",
|
|
||||||
"wearing",
|
|
||||||
"shoes",
|
|
||||||
"boots",
|
|
||||||
"shirt",
|
|
||||||
"panties",
|
|
||||||
"bra",
|
|
||||||
"skirt",
|
|
||||||
"dress",
|
|
||||||
"kimono",
|
|
||||||
"wings",
|
|
||||||
"horns",
|
|
||||||
"pants",
|
|
||||||
"shorts",
|
|
||||||
"leggins",
|
|
||||||
"sandals",
|
|
||||||
"hat",
|
|
||||||
"glasses",
|
|
||||||
"sweater",
|
|
||||||
"hoodie",
|
|
||||||
"sweatshirt",
|
|
||||||
]
|
|
||||||
|
|||||||
26
pipelines.py
26
pipelines.py
@@ -1,26 +0,0 @@
|
|||||||
from transformers import (
|
|
||||||
AutoModelForTokenClassification,
|
|
||||||
AutoTokenizer,
|
|
||||||
TokenClassificationPipeline,
|
|
||||||
)
|
|
||||||
from transformers.pipelines import AggregationStrategy
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
|
|
||||||
def __init__(self, model, *args, **kwargs):
|
|
||||||
super().__init__(
|
|
||||||
model=AutoModelForTokenClassification.from_pretrained(model),
|
|
||||||
tokenizer=AutoTokenizer.from_pretrained(model),
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def postprocess(self, model_outputs):
|
|
||||||
results = super().postprocess(
|
|
||||||
model_outputs=model_outputs,
|
|
||||||
aggregation_strategy=AggregationStrategy.SIMPLE
|
|
||||||
if self.model.config.model_type == "roberta"
|
|
||||||
else AggregationStrategy.FIRST,
|
|
||||||
)
|
|
||||||
return np.unique([result.get("word").strip() for result in results])
|
|
||||||
83
server.py
83
server.py
@@ -13,7 +13,7 @@ import markdown
|
|||||||
import argparse
|
import argparse
|
||||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||||
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
from transformers import BlipForConditionalGeneration
|
||||||
import unicodedata
|
import unicodedata
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
@@ -41,7 +41,7 @@ class SplitArgs(argparse.Action):
|
|||||||
|
|
||||||
# Script arguments
|
# Script arguments
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="TavernAI Extras", description="Web API for transformers models"
|
prog="SillyTavern Extras", description="Web API for transformers models"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, help="Specify the port on which the application is hosted"
|
"--port", type=int, help="Specify the port on which the application is hosted"
|
||||||
@@ -58,10 +58,6 @@ parser.add_argument(
|
|||||||
"--classification-model", help="Load a custom text classification model"
|
"--classification-model", help="Load a custom text classification model"
|
||||||
)
|
)
|
||||||
parser.add_argument("--captioning-model", help="Load a custom captioning model")
|
parser.add_argument("--captioning-model", help="Load a custom captioning model")
|
||||||
parser.add_argument(
|
|
||||||
"--keyphrase-model", help="Load a custom keyphrase extraction model"
|
|
||||||
)
|
|
||||||
parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
|
|
||||||
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
|
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
|
||||||
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
||||||
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
||||||
@@ -119,10 +115,6 @@ classification_model = (
|
|||||||
captioning_model = (
|
captioning_model = (
|
||||||
args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
||||||
)
|
)
|
||||||
keyphrase_model = (
|
|
||||||
args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
|
||||||
)
|
|
||||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
|
||||||
embedding_model = (
|
embedding_model = (
|
||||||
args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
||||||
)
|
)
|
||||||
@@ -178,21 +170,6 @@ if "classify" in modules:
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "keywords" in modules:
|
|
||||||
print("Initializing a keyword extraction pipeline...")
|
|
||||||
import pipelines as pipelines
|
|
||||||
|
|
||||||
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
|
|
||||||
|
|
||||||
if "prompt" in modules:
|
|
||||||
print("Initializing a prompt generator")
|
|
||||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
|
||||||
gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
|
||||||
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
|
|
||||||
prompt_generator = pipeline(
|
|
||||||
"text-generation", model=gpt_model, tokenizer=gpt_tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
if "sd" in modules and not sd_use_remote:
|
if "sd" in modules and not sd_use_remote:
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
from diffusers import EulerAncestralDiscreteScheduler
|
from diffusers import EulerAncestralDiscreteScheduler
|
||||||
@@ -362,29 +339,6 @@ def normalize_string(input: str) -> str:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def extract_keywords(text: str) -> list:
|
|
||||||
punctuation = "(){}[]\n\r<>"
|
|
||||||
trans = str.maketrans(punctuation, " " * len(punctuation))
|
|
||||||
text = text.translate(trans)
|
|
||||||
text = normalize_string(text)
|
|
||||||
return list(keyphrase_pipe(text))
|
|
||||||
|
|
||||||
|
|
||||||
def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> str:
|
|
||||||
prompt = ", ".join(keywords)
|
|
||||||
outs = prompt_generator(
|
|
||||||
prompt,
|
|
||||||
max_length=length,
|
|
||||||
num_return_sequences=num,
|
|
||||||
do_sample=True,
|
|
||||||
repetition_penalty=1.2,
|
|
||||||
temperature=0.7,
|
|
||||||
top_k=4,
|
|
||||||
early_stopping=True,
|
|
||||||
)
|
|
||||||
return [out["generated_text"] for out in outs]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image(data: dict) -> Image:
|
def generate_image(data: dict) -> Image:
|
||||||
prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
|
prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
|
||||||
|
|
||||||
@@ -552,39 +506,6 @@ def api_classify_labels():
|
|||||||
return jsonify({"labels": labels})
|
return jsonify({"labels": labels})
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/keywords", methods=["POST"])
|
|
||||||
@require_module("keywords")
|
|
||||||
def api_keywords():
|
|
||||||
data = request.get_json()
|
|
||||||
|
|
||||||
if "text" not in data or not isinstance(data["text"], str):
|
|
||||||
abort(400, '"text" is required')
|
|
||||||
|
|
||||||
print("Keywords input:", data["text"], sep="\n")
|
|
||||||
keywords = extract_keywords(data["text"])
|
|
||||||
print("Keywords output:", keywords, sep="\n")
|
|
||||||
return jsonify({"keywords": keywords})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/prompt", methods=["POST"])
|
|
||||||
@require_module("prompt")
|
|
||||||
def api_prompt():
|
|
||||||
data = request.get_json()
|
|
||||||
|
|
||||||
if "text" not in data or not isinstance(data["text"], str):
|
|
||||||
abort(400, '"text" is required')
|
|
||||||
|
|
||||||
keywords = extract_keywords(data["text"])
|
|
||||||
|
|
||||||
if "name" in data and isinstance(data["name"], str):
|
|
||||||
keywords.insert(0, data["name"])
|
|
||||||
|
|
||||||
print("Prompt input:", data["text"], sep="\n")
|
|
||||||
prompts = generate_prompt(keywords)
|
|
||||||
print("Prompt output:", prompts, sep="\n")
|
|
||||||
return jsonify({"prompts": prompts})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/image", methods=["POST"])
|
@app.route("/api/image", methods=["POST"])
|
||||||
@require_module("sd")
|
@require_module("sd")
|
||||||
def api_image():
|
def api_image():
|
||||||
|
|||||||
Reference in New Issue
Block a user