Use pipeline for image captioning

This commit is contained in:
Cohee
2023-12-20 01:28:41 +02:00
parent 7ca92eaeac
commit 423f51e3f8

View File

@@ -13,9 +13,8 @@ from flask_cors import CORS
from flask_compress import Compress
import markdown
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration
from transformers import AutoTokenizer, pipeline
from transformers import AutoModelForSeq2SeqLM
import unicodedata
import torch
import time
@@ -204,15 +203,7 @@ if "talkinghead" in modules:
if "caption" in modules:
print("Initializing an image captioning model...")
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
if "blip" in captioning_model:
captioning_transformer = BlipForConditionalGeneration.from_pretrained(
captioning_model, torch_dtype=torch_dtype
).to(device)
else:
captioning_transformer = AutoModelForCausalLM.from_pretrained(
captioning_model, torch_dtype=torch_dtype
).to(device)
captioning_pipeline = pipeline('image-to-text', model=captioning_model, device=device_string, torch_dtype=torch_dtype)
if "summarize" in modules:
print("Initializing a text summarization model...")
@@ -441,12 +432,8 @@ def classify_text(text: str) -> list:
return classify_module.classify_text_emotion(text)
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
device, torch_dtype
)
outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
def caption_image(raw_image: Image) -> str:
caption = captioning_pipeline(raw_image.convert("RGB"))[0]['generated_text']
return caption