Add module system

This commit is contained in:
SillyLossy
2023-03-01 18:37:33 +02:00
parent a8dc30be0c
commit 5fd7d96fe3
3 changed files with 89 additions and 61 deletions

View File

@@ -7,7 +7,18 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
* Run `pip install -r requirements.txt`
* Run `python server.py`
## Included functionality
## Modules
| Name | Description |
| ----------- | --------------------------------- |
| `caption` | Image captioning |
| `summarize` | Text summarization |
| `classify` | Text sentiment classification |
| `keywords` | Text key phrases extraction |
| `prompt` | SD prompt generation from text |
| `sd` | Stable Diffusion image generation |
## API Endpoints
### BLIP model for image captioning
`POST /api/caption`
#### **Input**
@@ -133,4 +144,5 @@ A set of unofficial APIs for various [TavernAI](https://github.com/TavernAI/Tave
| `--keyphrase-model` | Load a custom key phrase extraction model.<br>Expects a HuggingFace model ID.<br>Default: [ml6team/keyphrase-extraction-distilbert-inspec](https://huggingface.co/ml6team/keyphrase-extraction-distilbert-inspec) |
| `--prompt-model` | Load a custom GPT-2 prompt generation model.<br>Expects a HuggingFace model ID.<br>Default: [FredZhang7/anime-anything-promptgen-v2](https://huggingface.co/FredZhang7/anime-anything-promptgen-v2) |
| `--sd-model` | Load a custom Stable Diffusion image generation model.<br>Expects a HuggingFace model ID.<br>Default: [ckpt/anything-v4.5-vae-swapped](https://huggingface.co/ckpt/anything-v4.5-vae-swapped)<br>*Must have VAE pre-baked in PyTorch format or the output will look drab!* |
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
| `--sd-cpu` | Forces the Stable Diffusion generation pipeline to run on the CPU.<br>**SLOW!** |
| `--enable-modules` | Override a list of enabled modules. Runs with everything enabled by default.<br>Expects a comma-separated list of module names. See [Modules](#modules) |

22
pipelines.py Normal file
View File

@@ -0,0 +1,22 @@
from transformers import AutoModelForTokenClassification, AutoTokenizer, TokenClassificationPipeline
from transformers.pipelines import AggregationStrategy
import numpy as np
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs
)
def postprocess(self, model_outputs):
results = super().postprocess(
model_outputs=model_outputs,
aggregation_strategy=AggregationStrategy.SIMPLE
if self.model.config.model_type == "roberta"
else AggregationStrategy.FIRST,
)
return np.unique([result.get("word").strip() for result in results])

112
server.py
View File

@@ -3,8 +3,6 @@ import markdown
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
from transformers import AutoModelForTokenClassification, TokenClassificationPipeline
from transformers.pipelines import AggregationStrategy
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import unicodedata
import torch
@@ -12,7 +10,6 @@ import time
from PIL import Image
import base64
from io import BytesIO
import numpy as np
from diffusers import StableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler
@@ -25,6 +22,7 @@ DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
DEFAULT_SUMMARIZE_PARAMS = {
'temperature': 1.0,
'repetition_penalty': 1.0,
@@ -59,6 +57,8 @@ parser.add_argument('--sd-model',
help="Load a custom SD image generation model")
parser.add_argument('--sd-cpu',
help="Force the SD pipeline to run on the CPU")
parser.add_argument('--enable-modules', nargs='*', default=[],
help="Override a list of enabled modules")
args = parser.parse_args()
@@ -70,73 +70,49 @@ captioning_model = args.captioning_model if args.captioning_model else DEFAULT_C
keyphrase_model = args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
modules = args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else ALL_MODULES
# Models init
device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
print('Initializing BLIP image captioning model...')
blip_processor = AutoProcessor.from_pretrained(captioning_model)
blip = BlipForConditionalGeneration.from_pretrained(
captioning_model, torch_dtype=torch_dtype).to(device)
if 'caption' in modules:
print('Initializing BLIP image captioning model...')
blip_processor = AutoProcessor.from_pretrained(captioning_model)
blip = BlipForConditionalGeneration.from_pretrained(captioning_model, torch_dtype=torch_dtype).to(device)
print('Initializing BART text summarization model...')
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
bart = BartForConditionalGeneration.from_pretrained(
summarization_model, torch_dtype=torch_dtype).to(device)
if 'summarize' in modules:
print('Initializing BART text summarization model...')
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
bart = BartForConditionalGeneration.from_pretrained(summarization_model, torch_dtype=torch_dtype).to(device)
print('Initializing BERT sentiment classification model...')
bert_classifier = pipeline("text-classification", model=classification_model,
top_k=None, device=device, torch_dtype=torch_dtype)
if 'classify' in modules:
print('Initializing BERT sentiment classification model...')
bert_classifier = pipeline("text-classification", model=classification_model, top_k=None, device=device, torch_dtype=torch_dtype)
print('Initializing keyword extractor...')
if 'keywords' in modules:
print('Initializing keyword extractor...')
import pipelines as pipelines
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
if 'prompt' in modules:
print('Initializing GPT prompt generator')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
gpt_model = GPT2LMHeadModel.from_pretrained(prompt_model)
prompt_generator = pipeline('text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs
)
def postprocess(self, model_outputs):
results = super().postprocess(
model_outputs=model_outputs,
aggregation_strategy=AggregationStrategy.SIMPLE
if self.model.config.model_type == "roberta"
else AggregationStrategy.FIRST,
)
return np.unique([result.get("word").strip() for result in results])
keyphrase_pipe = KeyphraseExtractionPipeline(keyphrase_model)
print('Initializing GPT prompt generator')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
gpt_model = GPT2LMHeadModel.from_pretrained(
'FredZhang7/anime-anything-promptgen-v2')
prompt_generator = pipeline(
'text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
print('Initializing Stable Diffusion pipeline')
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
sd_device = torch.device(sd_device_string)
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
sd_pipe = StableDiffusionPipeline.from_pretrained(
sd_model,
custom_pipeline="lpw_stable_diffusion",
torch_dtype=sd_torch_dtype,
).to(sd_device)
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
sd_pipe.enable_attention_slicing()
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
sd_pipe.scheduler.config)
if 'sd' in modules:
print('Initializing Stable Diffusion pipeline')
sd_device_string = "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
sd_device = torch.device(sd_device_string)
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype).to(sd_device)
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
sd_pipe.enable_attention_slicing()
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
prompt_prefix = "best quality, absurdres, "
neg_prompt = """lowres, bad anatomy, error body, error hair, error arm,
@@ -264,6 +240,9 @@ def index():
@app.route('/api/caption', methods=['POST'])
def api_caption():
if 'caption' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'image' in data or not isinstance(data['image'], str):
@@ -276,6 +255,9 @@ def api_caption():
@app.route('/api/summarize', methods=['POST'])
def api_summarize():
if 'summarize' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
@@ -292,6 +274,9 @@ def api_summarize():
@app.route('/api/classify', methods=['POST'])
def api_classify():
if 'classify' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
@@ -303,6 +288,9 @@ def api_classify():
@app.route('/api/keywords', methods=['POST'])
def api_keywords():
if 'keywords' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
@@ -314,6 +302,9 @@ def api_keywords():
@app.route('/api/prompt', methods=['POST'])
def api_prompt():
if 'prompt' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'text' in data or not isinstance(data['text'], str):
@@ -330,6 +321,9 @@ def api_prompt():
@app.route('/api/image', methods=['POST'])
def api_image():
if 'sd' not in modules:
abort(403, 'Module is disabled by config')
data = request.get_json()
if not 'prompt' in data or not isinstance(data['prompt'], str):