Files
exllamav2/examples/multimodal_grounding_qwen.py
2025-01-29 22:41:36 +01:00

560 lines
19 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import sys
from PyQt5.QtWidgets import (
QApplication,
QMainWindow,
QVBoxLayout,
QHBoxLayout,
QLabel,
QPushButton,
QTextEdit,
QFileDialog,
QWidget,
QMessageBox,
)
from PyQt5.QtGui import (
QPixmap,
QImage,
QTextCursor,
QPainter,
QPen,
QFont,
QColor,
QRegion,
)
from PyQt5.QtCore import Qt, pyqtSignal, QRect, QBuffer
from PIL import Image
import requests
from io import BytesIO
import time
import pprint
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Tokenizer,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2Sampler,
)
class Model:
current_image: Image or None = None
current_description: str
def __init__(self, model_directory, bbox_mode: str):
self.model_directory = model_directory
self.config = None
self.vision_model = None
self.model = None
self.cache = None
self.tokenizer = None
self.current_image = None
self.current_emb = None
self.current_description = ""
bbox_funcs = {
"qwen2": self.get_grounding_bb_qwen2,
"qwen25": self.get_grounding_bb_qwen25,
}
self.bbox_func = bbox_funcs[bbox_mode]
def load(self):
"""Load and initialize the things"""
self.config = ExLlamaV2Config(self.model_directory)
self.config.max_seq_len = 8192
self.vision_model = ExLlamaV2VisionTower(self.config)
self.vision_model.load(progress = True)
self.model = ExLlamaV2(self.config)
self.cache = ExLlamaV2Cache(self.model, lazy = True, max_seq_len = 32768)
self.model.load_autosplit(self.cache, progress = True)
self.tokenizer = ExLlamaV2Tokenizer(self.config)
self.generator = ExLlamaV2DynamicGenerator(
model = self.model,
cache = self.cache,
tokenizer = self.tokenizer,
)
def set_image(self, image: Image):
w, h = image.size
print(f"New image: {w} x {h} pixels")
self.current_image = image
self.current_description = ""
def get_prompt(self):
prompt = (
"<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
self.current_emb.text_alias +
"Describe the image in detail." +
"\n" +
"<|im_start|>assistant\n"
)
return prompt
def inference(self, settext_fn, update_fn):
"""Run inference on the current image, stream results"""
if self.current_image is None:
settext_fn("No image loaded.")
return
settext_fn("")
update_fn()
self.current_emb = self.vision_model.get_image_embeddings(
model = self.model,
tokenizer = self.tokenizer,
image = self.current_image,
)
prompt = self.get_prompt()
input_ids = self.tokenizer.encode(
prompt,
add_bos = True,
encode_special_tokens = True,
embeddings = [self.current_emb],
)
job = ExLlamaV2DynamicJob(
input_ids = input_ids,
max_new_tokens = 1000,
decode_special_tokens = True,
stop_conditions = [self.tokenizer.eos_token_id],
gen_settings = ExLlamaV2Sampler.Settings.greedy(),
embeddings = [self.current_emb],
)
self.generator.enqueue(job)
text = ""
lastupdate = time.time()
while self.generator.num_remaining_jobs():
results = self.generator.iterate()
for result in results:
text += result.get("text", "")
# Update at max 30 fps
if time.time() - lastupdate > (1/30):
lastupdate = time.time()
settext_fn(text)
update_fn()
#
# text = \
# """And you may find yourself living in a shotgun shack
# And you may find yourself in another part of the world
# And you may find yourself behind the wheel of a large automobile
# And you may find yourself in a beautiful house, with a beautiful wife
# And you may ask yourself, "Well, how did I get here?\""""
settext_fn(text)
update_fn()
self.current_description = text
print("Image description from model:")
print(text)
def get_grounding_bb_qwen2(self, start, end) -> tuple:
"""
Prompt the model again and try to extraxt the bounding box of the image details indicated by selected portion
of the description. We do this by repeating the exact same prompt up to and including the selected text, but
enclosed in the special tokens that Qwen would emit when prompted for grounding. Qwen is then strongly biased
towards completing the bounding box.
Since we're using the same description as the model originally generated, all keys/values for the system
prompt, image and generated description up to the selection will be reused from the cache.
"""
if start >= end:
return None, None
# Including leading space
if start > 0 and self.current_description[start - 1] == " ":
start -= 1
# Repeat the same prompt up to the selection, with grounding tokens added
prompt = self.get_prompt()
prompt += self.current_description[:start]
prompt += "<|object_ref_start|>"
prompt += self.current_description[start:end]
prompt += "<|object_ref_end|><|box_start|>("
bb_string, res = self.generator.generate(
prompt = prompt,
add_bos = True,
max_new_tokens = 25,
stop_conditions = [self.tokenizer.single_id("<|box_end|>")],
gen_settings = ExLlamaV2Sampler.Settings.greedy(),
embeddings = [self.current_emb],
completion_only = True,
return_last_results = True, # debug purposes
)
bb_string = "(" + bb_string
print(f"Generation: {bb_string}")
pprint.pprint(res, indent = 4)
# BB string is in the format "(x0,y0),(x1,y1)" with integer coordinates normalized to a range of 1000x1000
try:
parts = bb_string.strip("()").split("),(")
a = tuple(map(int, parts[0].split(",")))
b = tuple(map(int, parts[1].split(",")))
a = (a[0] / 1000.0, a[1] / 1000.0)
b = (b[0] / 1000.0, b[1] / 1000.0)
except:
print("No bounding box could be determined")
a, b = None, None
return a, b
def get_grounding_bb_qwen25(self, start, end) -> tuple:
"""
Qwen2.5 works the same way, except the coordinates are no longer normalized and the format is:
"(x0,y0,x1,y1)"
"""
if start >= end:
return None, None
# Including leading space
if start > 0 and self.current_description[start - 1] == " ":
start -= 1
# Repeat the same prompt up to the selection, with grounding tokens added
prompt = self.get_prompt()
prompt += self.current_description[:start]
prompt += "<|object_ref_start|>"
prompt += self.current_description[start:end]
prompt += "<|object_ref_end|><|box_start|>("
bb_string, res = self.generator.generate(
prompt = prompt,
add_bos = True,
max_new_tokens = 28,
stop_conditions = [self.tokenizer.single_id("<|box_end|>")],
gen_settings = ExLlamaV2Sampler.Settings.greedy(),
embeddings = [self.current_emb],
completion_only = True,
return_last_results = True, # debug purposes
)
bb_string = "(" + bb_string
print(f"Generation: {bb_string}")
pprint.pprint(res, indent = 4)
# BB string is in the format "(x0,y0,x1,y1)" with integer coordinates
s = self.current_image.size
try:
d = tuple(map(int, bb_string.strip("()").split(",")))
a = (d[0] / s[0], d[1] / s[1])
b = (d[2] / s[0], d[3] / s[1])
except:
print("No bounding box could be determined")
a, b = None, None
return a, b
class GroundingDemo(QMainWindow):
model: Model
class CustomTextEdit(QTextEdit):
"""Custom QTextEdit that emits a signal when a selection is completed."""
selection_complete = pyqtSignal(tuple)
def mouseReleaseEvent(self, event):
"""Handle mouse release and emit the selection complete signal."""
super().mouseReleaseEvent(event)
cursor = self.textCursor()
if cursor.hasSelection():
# Start with the selected range
start = cursor.selectionStart()
end = cursor.selectionEnd()
# Move to the start of the selection and expand to the start of the word
cursor.setPosition(start)
cursor.movePosition(QTextCursor.StartOfWord, QTextCursor.MoveAnchor)
expanded_start = cursor.position()
# Move to the end of the selection and expand to the end of the word
cursor.setPosition(end)
cursor.movePosition(QTextCursor.EndOfWord, QTextCursor.MoveAnchor)
expanded_end = cursor.position()
# Update the selection
cursor.setPosition(expanded_start, QTextCursor.MoveAnchor)
cursor.setPosition(expanded_end, QTextCursor.KeepAnchor)
self.setTextCursor(cursor) # Update the visible selection
# Emit the expanded selection range
self.selection_complete.emit((expanded_start, expanded_end))
class CustomQLabel(QLabel):
def __init__(self, parent, callback):
super().__init__(parent)
self.setAcceptDrops(True)
self.callback = callback
self.bounding_box = None
self.scale = (1, 1)
def setEnabled(self, enabled):
"""Override setEnabled to prevent grayscaling."""
super().setEnabled(True)
def set_bounding_box(self, a, b):
"""Set the bounding box to be drawn."""
if a is None:
self.clear_bounding_box()
return
inner_rect = self.contentsRect()
w, h = inner_rect.width(), inner_rect.height()
iw, ih = self.scale
x1, y1 = a
x2, y2 = b
x1 = int(x1 * iw + (w - iw) / 2)
y1 = int(y1 * ih + (h - ih) / 2)
x2 = int(x2 * iw + (w - iw) / 2)
y2 = int(y2 * ih + (h - ih) / 2)
self.bounding_box = QRect(x1, y1, x2 - x1, y2 - y1)
self.update()
def clear_bounding_box(self):
"""Clear the bounding box."""
self.bounding_box = None
self.update()
def paintEvent(self, event):
"""Override paintEvent to draw the bounding box."""
super().paintEvent(event)
if self.bounding_box:
painter = QPainter(self)
overlay_color = QColor(64, 64, 64, 150)
painter.setBrush(overlay_color)
painter.setPen(Qt.NoPen)
full_region = QRegion(self.rect())
exclude_region = QRegion(self.bounding_box)
clip_region = full_region.subtracted(exclude_region)
painter.setClipRegion(clip_region)
painter.drawRect(self.rect())
pen = QPen(Qt.white, 2)
painter.setPen(pen)
painter.setBrush(overlay_color)
painter.drawRect(self.bounding_box)
def dragEnterEvent(self, event):
"""Handle drag enter events."""
if event.mimeData().hasUrls():
event.accept()
else:
event.ignore()
def dropEvent(self, event):
"""Handle drop events."""
if event.mimeData().hasUrls():
# Get the first file path or URL
urls = event.mimeData().urls()
url = urls[0]
file_path_or_url = url.toString()
if not file_path_or_url.startswith(("http://", "https://")):
file_path_or_url = url.toLocalFile()
self.callback(file_path_or_url) # Pass the local file path to the callback
else:
event.ignore()
def __init__(self, model: Model, title: str):
super().__init__()
self.model = model
self.no_events_plz = False
self.setWindowTitle(f"Grounding Demo - {title}")
self.setGeometry(100, 100, 1000, 1000)
font = QFont()
font.setPointSize(11) # Set the font size to 16 points
# Main layout
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
main_layout = QVBoxLayout(self.central_widget)
# Image display
self.image_label = self.CustomQLabel(self, self.load_dropped_image)
self.image_label.setText("Image goes here")
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setStyleSheet("background-color: #404040; color: white;")
self.image_label.setFont(font)
main_layout.addWidget(self.image_label, stretch = 5)
# Button row
button_layout = QHBoxLayout()
main_layout.addLayout(button_layout)
self.paste_button = QPushButton("Paste Image")
self.paste_button.clicked.connect(self.paste_image)
self.paste_button.setFont(font)
button_layout.addWidget(self.paste_button)
self.load_button = QPushButton("Load Image")
self.load_button.clicked.connect(self.load_image)
self.load_button.setFont(font)
button_layout.addWidget(self.load_button)
self.inference_button = QPushButton("Inference")
self.inference_button.clicked.connect(self.run_inference)
self.inference_button.setFont(font)
button_layout.addWidget(self.inference_button)
# Model output
self.output_label = QLabel("Model Output:", self)
self.output_label.setStyleSheet("color: white;")
self.output_label.setFont(font)
main_layout.addWidget(self.output_label)
self.output_text = self.CustomTextEdit(self)
self.output_text.setReadOnly(True)
self.output_text.setStyleSheet("background-color: #3C3C3C; color: white;")
self.output_text.setFont(font)
main_layout.addWidget(self.output_text, stretch = 2)
self.output_text.selection_complete.connect(self.on_selection_made)
self.previous_selection = ""
# Set dark theme
self.setStyleSheet("background-color: #2E2E2E; color: white;")
def load_dropped_image(self, file_path_or_url):
"""Load an image when it is dropped on the image label."""
try:
if file_path_or_url.startswith(("http://", "https://")):
# Handle web URL
response = requests.get(file_path_or_url)
response.raise_for_status()
image = Image.open(BytesIO(response.content))
else:
image = Image.open(file_path_or_url)
self.display_image(image, from_pil=True)
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to load image: {e}")
def paste_image(self):
"""Paste an image from the clipboard."""
clipboard = QApplication.clipboard()
mime_data = clipboard.mimeData()
if mime_data.hasImage():
qt_image = clipboard.image()
self.display_image(qt_image, from_pil = False)
else:
QMessageBox.warning(self, "Error", "No image found in clipboard.")
def load_image(self):
"""Open a file dialog to load an image."""
file_path, _ = QFileDialog.getOpenFileName(
self, "Open Image", "", "Image Files (*.png *.jpg *.jpeg *.bmp *.gif *.tiff)"
)
if file_path:
try:
image = Image.open(file_path)
self.display_image(image, from_pil = True)
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to load image: {e}")
def run_inference(self):
try:
self.no_events_plz = True
self.paste_button.setEnabled(False)
self.load_button.setEnabled(False)
self.inference_button.setEnabled(False)
self.output_text.setText("")
self.model.inference(self.output_text.setText, QApplication.processEvents)
finally:
self.no_events_plz = False
self.paste_button.setEnabled(True)
self.load_button.setEnabled(True)
self.inference_button.setEnabled(True)
def display_image(self, image, from_pil):
# Convert and display the image
if self.no_events_plz:
return
if from_pil:
self.model.set_image(image)
image = image.convert("RGBA")
data = image.tobytes("raw", "RGBA")
q_image = QImage(data, image.width, image.height, QImage.Format_RGBA8888)
else:
# If the image comes from a QImage (e.g., clipboard), convert to PIL
self.model.set_image(self.qimage_to_pil(image))
q_image = image
pixmap = QPixmap.fromImage(q_image)
scaled_pixmap = pixmap.scaled(self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
scaled_width = scaled_pixmap.width()
scaled_height = scaled_pixmap.height()
self.image_label.setPixmap(scaled_pixmap)
self.image_label.clear_bounding_box()
self.image_label.scale = (scaled_width, scaled_height)
self.output_text.setText("")
def qimage_to_pil(self, q_image):
"""Convert a QImage to a PIL Image."""
buffer = QBuffer()
buffer.open(QBuffer.ReadWrite)
q_image.save(buffer, "PNG")
pil_image = Image.open(BytesIO(buffer.data()))
return pil_image
def on_selection_made(self, pos):
"""Callback for when a selection is made."""
if self.no_events_plz:
return
start, end = pos
# start, end = model.expand_selection(start, end)
print(f"Selected span: {start}, {end}")
print(f"Selected text: {repr(self.model.current_description[start:end])}")
a, b = self.model.bbox_func(start, end)
self.image_label.set_bounding_box(a, b)
# Qwen2-VL 7B:
# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
def main():
# model_dir = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
# bbox_mode = "qwen25"
model_dir = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
bbox_mode = "qwen25"
app = QApplication(sys.argv)
model = Model(model_dir, bbox_mode)
model.load()
window = GroundingDemo(model, model_dir)
window.show()
sys.exit(app.exec_())
if __name__ == "__main__":
main()