mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-21 15:09:01 +00:00
Add module system
This commit is contained in:
16
README.md
16
README.md
@@ -7,7 +7,18 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
|
||||
* Run `pip install -r requirements.txt`
|
||||
* Run `python server.py`
|
||||
|
||||
## Included functionality
|
||||
## Modules
|
||||
|
||||
| Name | Description |
|
||||
| ----------- | --------------------------------- |
|
||||
| `caption` | Image captioning |
|
||||
| `summarize` | Text summarization |
|
||||
| `classify` | Text sentiment classification |
|
||||
| `keywords` | Text key phrases extraction |
|
||||
| `prompt` | SD prompt generation from text |
|
||||
| `sd` | Stable Diffusion image generation |
|
||||
|
||||
## API Endpoints
|
||||
### BLIP model for image captioning
|
||||
`POST /api/caption`
|
||||
#### **Input**
|
||||
@@ -133,4 +144,5 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
|
||||
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
|
||||
| `--prompt-model` | Load a custom GPT-2 prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
|
||||
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
|
||||
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
|
||||
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
|
||||
| `--enable-modules` | Override a list of enabled modules. Runs with everything enabled by default.<br>Expects a comma-separated list of module names. See [Modules](#modules) |
|
||||
22
pipelines.py
Normal file
22
pipelines.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from transformers import AutoModelForTokenClassification, AutoTokenizer, TokenClassificationPipeline
|
||||
from transformers.pipelines import AggregationStrategy
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(
|
||||
model=AutoModelForTokenClassification.from_pretrained(model),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model),
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
results = super().postprocess(
|
||||
model_outputs=model_outputs,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE
|
||||
if self.model.config.model_type == "roberta"
|
||||
else AggregationStrategy.FIRST,
|
||||
)
|
||||
return np.unique([result.get("word").strip() for result in results])
|
||||
112
server.py
112
server.py
@@ -3,8 +3,6 @@ import markdown
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
|
||||
from transformers import AutoModelForTokenClassification, TokenClassificationPipeline
|
||||
from transformers.pipelines import AggregationStrategy
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import unicodedata
|
||||
import torch
|
||||
@@ -12,7 +10,6 @@ import time
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
@@ -25,6 +22,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
||||
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
|
||||
DEFAULT_SUMMARIZE_PARAMS = {
|
||||
'temperature': 1.0,
|
||||
'repetition_penalty': 1.0,
|
||||
@@ -59,6 +57,8 @@ parser.add_argument('--sd-model',
|
||||
help="Load a custom SD image generation model")
|
||||
parser.add_argument('--sd-cpu',
|
||||
help="Force the SD pipeline to run on the CPU")
|
||||
parser.add_argument('--enable-modules', nargs='*', default=[],
|
||||
help="Override a list of enabled modules")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -70,73 +70,49 @@ captioning_model = args.captioning_model if args.captioning_model else DEFAULT_C
|
||||
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
||||
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
||||
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
||||
modules = args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else ALL_MODULES
|
||||
|
||||
# Models init
|
||||
device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device = torch.device(device_string)
|
||||
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
||||
|
||||
print('Initializing BLIP image captioning model...')
|
||||
blip_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
blip = BlipForConditionalGeneration.from_pretrained(
|
||||
captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
if 'caption' in modules:
|
||||
print('Initializing BLIP image captioning model...')
|
||||
blip_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
blip = BlipForConditionalGeneration.from_pretrained(captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
print('Initializing BART text summarization model...')
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
bart = BartForConditionalGeneration.from_pretrained(
|
||||
summarization_model, torch_dtype=torch_dtype).to(device)
|
||||
if 'summarize' in modules:
|
||||
print('Initializing BART text summarization model...')
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
bart = BartForConditionalGeneration.from_pretrained(summarization_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
print('Initializing BERT sentiment classification model...')
|
||||
bert_classifier = pipeline("text-classification", model=classification_model,
|
||||
top_k=None, device=device, torch_dtype=torch_dtype)
|
||||
if 'classify' in modules:
|
||||
print('Initializing BERT sentiment classification model...')
|
||||
bert_classifier = pipeline("text-classification", model=classification_model, top_k=None, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
print('Initializing keyword extractor...')
|
||||
if 'keywords' in modules:
|
||||
print('Initializing keyword extractor...')
|
||||
import pipelines as pipelines
|
||||
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
|
||||
|
||||
if 'prompt' in modules:
|
||||
print('Initializing GPT prompt generator')
|
||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
gpt_model = GPT2LMHeadModel.from_pretrained(prompt_model)
|
||||
prompt_generator = pipeline('text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
|
||||
|
||||
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(
|
||||
model=AutoModelForTokenClassification.from_pretrained(model),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model),
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
results = super().postprocess(
|
||||
model_outputs=model_outputs,
|
||||
aggregation_strategy=AggregationStrategy.SIMPLE
|
||||
if self.model.config.model_type == "roberta"
|
||||
else AggregationStrategy.FIRST,
|
||||
)
|
||||
return np.unique([result.get("word").strip() for result in results])
|
||||
|
||||
|
||||
keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model)
|
||||
|
||||
print('Initializing GPT prompt generator')
|
||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
gpt_model = GPT2LMHeadModel.from_pretrained(
|
||||
'FredZhang7/anime-anything-promptgen-v2')
|
||||
prompt_generator = pipeline(
|
||||
'text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
|
||||
|
||||
|
||||
print('Initializing Stable Diffusion pipeline')
|
||||
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
||||
sd_device = torch.device(sd_device_string)
|
||||
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
sd_model,
|
||||
custom_pipeline="lpw_stable_diffusion",
|
||||
torch_dtype=sd_torch_dtype,
|
||||
).to(sd_device)
|
||||
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
|
||||
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||
sd_pipe.scheduler.config)
|
||||
if 'sd' in modules:
|
||||
print('Initializing Stable Diffusion pipeline')
|
||||
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
||||
sd_device = torch.device(sd_device_string)
|
||||
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype).to(sd_device)
|
||||
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
|
||||
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
|
||||
prompt_prefix = "best quality, absurdres, "
|
||||
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
|
||||
@@ -264,6 +240,9 @@ def index():
|
||||
|
||||
@app.route('/api/caption', methods=['POST'])
|
||||
def api_caption():
|
||||
if 'caption' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'image' in data or not isinstance(data['image'], str):
|
||||
@@ -276,6 +255,9 @@ def api_caption():
|
||||
|
||||
@app.route('/api/summarize', methods=['POST'])
|
||||
def api_summarize():
|
||||
if 'summarize' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
@@ -292,6 +274,9 @@ def api_summarize():
|
||||
|
||||
@app.route('/api/classify', methods=['POST'])
|
||||
def api_classify():
|
||||
if 'classify' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
@@ -303,6 +288,9 @@ def api_classify():
|
||||
|
||||
@app.route('/api/keywords', methods=['POST'])
|
||||
def api_keywords():
|
||||
if 'keywords' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
@@ -314,6 +302,9 @@ def api_keywords():
|
||||
|
||||
@app.route('/api/prompt', methods=['POST'])
|
||||
def api_prompt():
|
||||
if 'prompt' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
@@ -330,6 +321,9 @@ def api_prompt():
|
||||
|
||||
@app.route('/api/image', methods=['POST'])
|
||||
def api_image():
|
||||
if 'sd' not in modules:
|
||||
abort(403, 'Module is disabled by config')
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'prompt' in data or not isinstance(data['prompt'], str):
|
||||
|
||||
Reference in New Issue
Block a user