mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-27 09:52:03 +00:00
Use pipeline for image captioning
This commit is contained in:
23
server.py
23
server.py
@@ -13,9 +13,8 @@ from flask_cors import CORS
|
||||
from flask_compress import Compress
|
||||
import markdown
|
||||
import argparse
|
||||
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from transformers import BlipForConditionalGeneration
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
import unicodedata
|
||||
import torch
|
||||
import time
|
||||
@@ -204,15 +203,7 @@ if "talkinghead" in modules:
|
||||
|
||||
if "caption" in modules:
|
||||
print("Initializing an image captioning model...")
|
||||
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
|
||||
if "blip" in captioning_model:
|
||||
captioning_transformer = BlipForConditionalGeneration.from_pretrained(
|
||||
captioning_model, torch_dtype=torch_dtype
|
||||
).to(device)
|
||||
else:
|
||||
captioning_transformer = AutoModelForCausalLM.from_pretrained(
|
||||
captioning_model, torch_dtype=torch_dtype
|
||||
).to(device)
|
||||
captioning_pipeline = pipeline('image-to-text', model=captioning_model, device=device_string, torch_dtype=torch_dtype)
|
||||
|
||||
if "summarize" in modules:
|
||||
print("Initializing a text summarization model...")
|
||||
@@ -441,12 +432,8 @@ def classify_text(text: str) -> list:
|
||||
return classify_module.classify_text_emotion(text)
|
||||
|
||||
|
||||
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
|
||||
inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
|
||||
device, torch_dtype
|
||||
)
|
||||
outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
|
||||
def caption_image(raw_image: Image) -> str:
|
||||
caption = captioning_pipeline(raw_image.convert("RGB"))[0]['generated_text']
|
||||
return caption
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user