mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-24 08:28:59 +00:00
Use generic model names where possible
This commit is contained in:
81
server.py
81
server.py
@@ -3,8 +3,8 @@ from flask_cors import CORS
|
||||
import markdown
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||
from transformers import BlipForConditionalGeneration, BartForConditionalGeneration
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
@@ -22,6 +22,7 @@ from diffusers import EulerAncestralDiscreteScheduler
|
||||
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
|
||||
DEFAULT_SUMMARIZATION_MODEL = 'Qiliang/bart-large-cnn-samsum-ChatGPT_v3'
|
||||
DEFAULT_CLASSIFICATION_MODEL = 'bhadresh-savani/distilbert-base-uncased-emotion'
|
||||
# Also try: 'Salesforce/blip-image-captioning-large' or 'microsoft/git-large-r-textcaps'
|
||||
DEFAULT_CAPTIONING_MODEL = 'Salesforce/blip-image-captioning-base'
|
||||
DEFAULT_KEYPHRASE_MODEL = 'ml6team/keyphrase-extraction-distilbert-inspec'
|
||||
DEFAULT_PROMPT_MODEL = 'FredZhang7/anime-anything-promptgen-v2'
|
||||
@@ -42,21 +43,21 @@ parser = argparse.ArgumentParser(
|
||||
parser.add_argument('--port', type=int,
|
||||
help="Specify the port on which the application is hosted")
|
||||
parser.add_argument('--listen', action='store_true',
|
||||
help="Hosts the app on the local network")
|
||||
help="Host the app on the local network")
|
||||
parser.add_argument('--share', action='store_true',
|
||||
help="Shares the app on CloudFlare tunnel")
|
||||
help="Share the app on CloudFlare tunnel")
|
||||
parser.add_argument('--cpu', action='store_true',
|
||||
help="Runs the models on the CPU")
|
||||
help="Run the models on the CPU")
|
||||
parser.add_argument('--summarization-model',
|
||||
help="Load a custom BART summarization model")
|
||||
help="Load a custom summarization model")
|
||||
parser.add_argument('--classification-model',
|
||||
help="Load a custom BERT text classification model")
|
||||
help="Load a custom text classification model")
|
||||
parser.add_argument('--captioning-model',
|
||||
help="Load a custom BLIP captioning model")
|
||||
help="Load a custom captioning model")
|
||||
parser.add_argument('--keyphrase-model',
|
||||
help="Load a custom keyphrase extraction model")
|
||||
parser.add_argument('--prompt-model',
|
||||
help="Load a custom GPT-2 prompt generation model")
|
||||
help="Load a custom prompt generation model")
|
||||
parser.add_argument('--sd-model',
|
||||
help="Load a custom SD image generation model")
|
||||
parser.add_argument('--sd-cpu',
|
||||
@@ -77,34 +78,37 @@ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
||||
modules = args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else ALL_MODULES
|
||||
|
||||
# Models init
|
||||
device_string = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
|
||||
device = torch.device(device_string)
|
||||
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
||||
|
||||
if 'caption' in modules:
|
||||
print('Initializing BLIP image captioning model...')
|
||||
blip_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
blip = BlipForConditionalGeneration.from_pretrained(captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
print('Initializing an image captioning model...')
|
||||
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
if 'blip' in captioning_model:
|
||||
captioning_transformer = BlipForConditionalGeneration.from_pretrained(captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
else:
|
||||
captioning_transformer = AutoModelForCausalLM.from_pretrained(captioning_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
if 'summarize' in modules:
|
||||
print('Initializing BART text summarization model...')
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
bart = BartForConditionalGeneration.from_pretrained(summarization_model, torch_dtype=torch_dtype).to(device)
|
||||
print('Initializing a text summarization model...')
|
||||
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(summarization_model, torch_dtype=torch_dtype).to(device)
|
||||
|
||||
if 'classify' in modules:
|
||||
print('Initializing BERT sentiment classification model...')
|
||||
bert_classifier = pipeline("text-classification", model=classification_model, top_k=None, device=device, torch_dtype=torch_dtype)
|
||||
print('Initializing a sentiment classification pipeline...')
|
||||
classification_pipe = pipeline("text-classification", model=classification_model, top_k=None, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
if 'keywords' in modules:
|
||||
print('Initializing keyword extractor...')
|
||||
print('Initializing a keyword extraction pipeline...')
|
||||
import pipelines as pipelines
|
||||
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
|
||||
|
||||
if 'prompt' in modules:
|
||||
print('Initializing GPT prompt generator')
|
||||
print('Initializing a prompt generator')
|
||||
gpt_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
||||
gpt_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
gpt_model = GPT2LMHeadModel.from_pretrained(prompt_model)
|
||||
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
|
||||
prompt_generator = pipeline('text-generation', model=gpt_model, tokenizer=gpt_tokenizer)
|
||||
|
||||
if 'sd' in modules:
|
||||
@@ -157,28 +161,27 @@ def load_extensions():
|
||||
|
||||
# AI stuff
|
||||
def classify_text(text: str) -> list:
|
||||
output = bert_classifier(text)[0]
|
||||
output = classification_pipe(text)[0]
|
||||
return sorted(output, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
|
||||
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
|
||||
inputs = blip_processor(raw_image.convert(
|
||||
'RGB'), return_tensors="pt").to(device, torch_dtype)
|
||||
outputs = blip.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
caption = blip_processor.decode(outputs[0], skip_special_tokens=True)
|
||||
inputs = captioning_processor(raw_image.convert('RGB'), return_tensors="pt").to(device, torch_dtype)
|
||||
outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
|
||||
return caption
|
||||
|
||||
|
||||
def summarize(text: str, params: dict) -> str:
|
||||
# Tokenize input
|
||||
inputs = bart_tokenizer(text, return_tensors="pt")
|
||||
inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
|
||||
token_count = len(inputs[0])
|
||||
|
||||
bad_words_ids = [
|
||||
bart_tokenizer(bad_word, add_special_tokens=True).input_ids
|
||||
summarization_tokenizer(bad_word, add_special_tokens=True).input_ids
|
||||
for bad_word in params['bad_words']
|
||||
]
|
||||
summary_ids = bart.generate(
|
||||
summary_ids = summarization_transformer.generate(
|
||||
inputs["input_ids"],
|
||||
num_beams=2,
|
||||
min_length=min(token_count, int(params['min_length'])),
|
||||
@@ -188,7 +191,7 @@ def summarize(text: str, params: dict) -> str:
|
||||
length_penalty=float(params['length_penalty']),
|
||||
bad_words_ids=bad_words_ids,
|
||||
)
|
||||
summary = bart_tokenizer.batch_decode(
|
||||
summary = summarization_tokenizer.batch_decode(
|
||||
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)[0]
|
||||
summary = normalize_string(summary)
|
||||
@@ -283,7 +286,7 @@ def get_asset(name: str, asset: str):
|
||||
extension = [element for element in extensions if element['name'] == name]
|
||||
if len(extension) == 0:
|
||||
abort(404)
|
||||
if not asset in extension[0]['metadata']['assets']:
|
||||
if asset not in extension[0]['metadata']['assets']:
|
||||
abort(404)
|
||||
return send_from_directory(os.path.join('./extensions', extension[0]['name'], 'assets'), asset)
|
||||
|
||||
@@ -295,7 +298,7 @@ def api_caption():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'image' in data or not isinstance(data['image'], str):
|
||||
if 'image' not in data or not isinstance(data['image'], str):
|
||||
abort(400, '"image" is required')
|
||||
|
||||
image = Image.open(BytesIO(base64.b64decode(data['image'])))
|
||||
@@ -310,7 +313,7 @@ def api_summarize():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
params = DEFAULT_SUMMARIZE_PARAMS.copy()
|
||||
@@ -318,7 +321,7 @@ def api_summarize():
|
||||
if 'params' in data and isinstance(data['params'], dict):
|
||||
params.update(data['params'])
|
||||
|
||||
summary = summarize(data['text'], params)[0]
|
||||
summary = summarize(data['text'], params)
|
||||
return jsonify({'summary': summary})
|
||||
|
||||
|
||||
@@ -329,7 +332,7 @@ def api_classify():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
classification = classify_text(data['text'])
|
||||
@@ -343,7 +346,7 @@ def api_keywords():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
keywords = extract_keywords(data['text'])
|
||||
@@ -357,12 +360,12 @@ def api_prompt():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'text' in data or not isinstance(data['text'], str):
|
||||
if 'text' not in data or not isinstance(data['text'], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
keywords = extract_keywords(data['text'])
|
||||
|
||||
if 'name' in data or isinstance(data['name'], str):
|
||||
if 'name' in data and isinstance(data['name'], str):
|
||||
keywords.insert(0, data['name'])
|
||||
|
||||
prompts = generate_prompt(keywords)
|
||||
@@ -376,7 +379,7 @@ def api_image():
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
if not 'prompt' in data or not isinstance(data['prompt'], str):
|
||||
if 'prompt' not in data or not isinstance(data['prompt'], str):
|
||||
abort(400, '"prompt" is required')
|
||||
|
||||
image = generate_image(data['prompt'])
|
||||
|
||||
Reference in New Issue
Block a user