Add image search example

This commit is contained in:
turboderp
2025-06-06 16:32:25 +02:00
parent 7989c93b6a
commit 7ea9559837
2 changed files with 221 additions and 0 deletions

130
examples/imgsearch.py Normal file
View File

@@ -0,0 +1,130 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import argparse
from exllamav3 import Model, model_init
from exllamav3.util.progress import ProgressBar
from PIL import Image
import glob
from pathlib import (Path)
from imgsearch_gallery import gallery
"""
Small example of using a VLM to query multiple images with the same yes/no question and collect all the positive
results. Note that only Gemma3 currently benefits from batch sizes higher than 1. The model is only sampled for one
token per image and only the relative probabilities of 'Yes' and 'No' are considered. A more advanced pipeline is
probably required to take full advantage of most models.
Example:
python examples/imgsearch.py /path/to/images -m /path/to/vlm -v -p "Are there any cats in this image?"
"""
def resolve_files(input_path):
    """Expand *input_path* — a directory, a single file, or a glob pattern — into a list of file paths."""
    path = Path(input_path)
    # A directory is searched recursively; only regular files are kept
    if path.is_dir():
        return [str(f) for f in path.rglob("*") if f.is_file()]
    # An existing file resolves to itself
    if path.is_file():
        return [str(path)]
    # Anything else is treated as a (possibly recursive) glob pattern
    matches = glob.glob(str(path), recursive = True)
    return [m for m in matches if Path(m).is_file()]
def get_token_mask(tokenizer, substr):
    """
    Build a boolean mask over the tokenizer vocabulary.

    Returns a 1-D bool tensor, one element per vocab entry, True for every token whose piece text
    starts with *substr*, compared case-insensitively.

    NOTE(review): the previous version computed two identical uppercase copies of *substr*
    (`substr1`/`substr2`, both `substr.upper()`) and tested both — the second test was dead code,
    presumably a typo for a second casing variant. A single case-folded prefix check is equivalent.
    """
    vocab = tokenizer.get_id_to_piece_list()
    prefix = substr.upper()
    mask = torch.tensor([token.upper().startswith(prefix) for token in vocab])
    return mask
@torch.inference_mode()
def main(args):
    """
    Scan the input images with a VLM, asking the same yes/no question about each one, and collect
    every file whose answer matches ("Yes" by default, "No" with --no). Only one token is sampled
    per image; the decision compares the summed probability mass of yes-like vs no-like tokens.
    """
    # Resolve filenames
    input_files = []
    for arg in args.input:
        input_files += resolve_files(arg)

    # Prepare model etc.
    model, config, _, tokenizer = model_init.init(args)

    # Load the image component model
    vision_model = Model.from_config(config, component = "vision")
    vision_model.load(progressbar = True)

    # Batched ingestion requires fixed-size image embeddings; fall back to one image at a time
    batchsize = args.batchsize
    if batchsize > 1 and not vision_model.caps.get("fixed_size_image_embeddings"):
        print(" !! Cannot do batched image ingestion with this model, falling back to batch size 1")
        batchsize = 1

    # Output masks: a vocab-wide mask per answer so every casing/spelling of "yes"/"no"
    # contributes to the same score
    yes_mask = get_token_mask(tokenizer, "yes")
    no_mask = get_token_mask(tokenizer, "no")
    vocab_size = tokenizer.actual_vocab_size  # To account for padded logits

    # Results
    skipped_files = 0
    total_files = 0
    all_matches = []

    # Process images
    with ProgressBar(" -- Inference", count = len(input_files)) as pb:
        idx = 0
        batch_files = []
        batch_images = []
        while idx < len(input_files):
            try:
                img = Image.open(input_files[idx])
                batch_files.append(input_files[idx])
                batch_images.append(img)
                total_files += 1
            except (IOError, SyntaxError):
                # Skip non-image files and ignore other errors
                skipped_files += 1
            idx += 1
            # Flush the batch when it is full or the input list is exhausted. The batch_images
            # guard avoids a zero-length forward pass when the trailing files were all skipped.
            if batch_images and (len(batch_images) == batchsize or idx == len(input_files)):
                batch_embed = vision_model.get_image_embeddings(tokenizer, batch_images)
                batch_prompt = [
                    model.default_chat_prompt(f"{be.text_alias}\n{args.prompt.strip()}")
                    for be in batch_embed
                ]
                input_ids = tokenizer.encode(batch_prompt, embeddings = batch_embed)
                params = {
                    "last_tokens_only": 1,
                    "indexed_embeddings": batch_embed
                }
                logits = model.forward(input_ids, params)[:, :, :vocab_size]
                probs = logits.softmax(dim = -1)
                probs = probs.cpu()
                # Summed probability of all yes-like / no-like next tokens, per batch item
                yes = torch.sum(probs * yes_mask, dim = -1)
                no = torch.sum(probs * no_mask, dim = -1)
                for m, filename in enumerate(batch_files):
                    if (not args.no and yes[m] > no[m]) or (args.no and yes[m] < no[m]):
                        # Fix: report the matched file — the previous code printed a literal
                        # placeholder and never used `filename`
                        print(f" -- Match: {filename}")
                        all_matches.append(filename)
                batch_files = []
                batch_images = []
            pb.update(idx)

    # Results
    print(f" -- Total files checked: {total_files:,}")
    print(f" -- Skipped files: {skipped_files:,}")
    if args.view:
        gallery(all_matches, args.prompt + (" (No)" if args.no else " (Yes)"))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Common model loading arguments from exllamav3; no KV cache args needed since each image
    # is scored with a single forward pass
    model_init.add_args(parser, cache = False)
    parser.add_argument("-p", "--prompt", type = str, help = "Per-image prompt (yes/no question)", required = True)
    parser.add_argument("-n", "--no", action = "store_true", help = "Match images on 'No' instead of 'Yes'")
    parser.add_argument("-bsz", "--batchsize", type = int, help = "Batch size", default = 1)
    parser.add_argument("-v", "--view", action = "store_true", help = "View results after search")
    # One or more paths: directories, files, or glob patterns (see resolve_files)
    parser.add_argument("input", nargs = "+", type = str, help = "Input files")
    _args = parser.parse_args()
    main(_args)

View File

@@ -0,0 +1,91 @@
"""
Quick & dirty image gallery using TKInter and Pillow
"""
THUMBNAIL_SIZE = 256
THUMBNAIL_PADDING = 10
BACKGROUND = "#202020"
HOVER_BG = "#aaaaaa"
THUMB_TOTAL_WIDTH = THUMBNAIL_SIZE + THUMBNAIL_PADDING
SCREEN_PADDING = 50
DEFAULT_SIZE = "1094x720"
KWSTYLES = {"borderwidth": 0, "highlightthickness": 0, "bg": BACKGROUND}
def gallery(image_paths: list[str], title: str):
    """
    Open a Tk window titled *title* showing *image_paths* as a scrollable grid of thumbnails.

    Clicking a thumbnail opens the image full size (scaled down to fit the screen, never up);
    Escape or a click closes the popup, Escape closes the main window. Blocks in mainloop()
    until the window is closed.
    """
    # Imported lazily so importing this module doesn't require a display / Tk
    import tkinter as tk
    from tkinter import Toplevel
    from PIL import Image, ImageTk, ImageOps

    def make_thumbnail(path):
        # Letterbox into a fixed square so the grid stays uniform regardless of aspect ratio
        img = Image.open(path)
        thumb_img = ImageOps.pad(img, (THUMBNAIL_SIZE, THUMBNAIL_SIZE), color = BACKGROUND)
        return ImageTk.PhotoImage(thumb_img)

    def show_full_image(path):
        # Popup window showing the image at full size, downscaled to fit the screen if needed
        top = Toplevel()
        top.bind("<Escape>", lambda e: top.destroy())
        top.bind("<Button-1>", lambda e: top.destroy())
        top.title(path)
        screen_w, screen_h = top.winfo_screenwidth(), top.winfo_screenheight()
        max_w = screen_w - SCREEN_PADDING * 2
        max_h = screen_h - SCREEN_PADDING * 2
        img = Image.open(path)
        img_w, img_h = img.size
        # min(..., 1.0): shrink to fit, never enlarge
        scale = min(max_w / img_w, max_h / img_h, 1.0)
        if scale < 1.0:
            img = img.resize((int(img_w * scale), int(img_h * scale)), Image.LANCZOS)
        tk_img = ImageTk.PhotoImage(img)
        lbl = tk.Label(top, image = tk_img, **KWSTYLES)
        # Keep a Python reference to the PhotoImage or Tk will show a blank label
        lbl.image = tk_img
        lbl.pack()

    def draw_grid(columns):
        # Rebuild the whole grid at the given column count (called on every resize)
        for widget in scrollable_frame.winfo_children():
            widget.destroy()
        for i, (thumb, path) in enumerate(thumbnails):
            label = tk.Label(scrollable_frame, image = thumb, bg = BACKGROUND, cursor = "hand2")
            # Keep a reference to the PhotoImage (see show_full_image)
            label.image = thumb
            # p = path / lbl = label default args bind the current loop value, avoiding the
            # late-binding-closure pitfall
            label.bind("<Button-1>", lambda e, p = path: show_full_image(p))
            label.grid(row = i // columns, column = i % columns, padx = 5, pady = 5)
            label.bind("<Enter>", lambda e, lbl=label: lbl.configure(bg = HOVER_BG))
            label.bind("<Leave>", lambda e, lbl=label: lbl.configure(bg = BACKGROUND))

    def on_resize(event):
        # Re-flow to as many columns as fit the new window width (at least one)
        columns = max(1, (event.width - THUMBNAIL_PADDING) // THUMB_TOTAL_WIDTH)
        draw_grid(columns)

    def on_mousewheel(event):
        # event.num 4/5 are X11 scroll-button events; event.delta is the Windows/macOS wheel delta
        if event.num == 4 or event.delta > 0:
            canvas.yview_scroll(-1, "units")
        elif event.num == 5 or event.delta < 0:
            canvas.yview_scroll(1, "units")

    # Layout: canvas + inner frame is the standard Tk idiom for a scrollable widget area
    root = tk.Tk()
    root.bind("<Escape>", lambda e: root.destroy())
    root.title(title)
    root.geometry(DEFAULT_SIZE)
    root.configure(**KWSTYLES)
    scrollbar = tk.Scrollbar(root,orient = "vertical", troughcolor = "#222", activebackground = "#666", bd = 0, **KWSTYLES)
    canvas = tk.Canvas(root, **KWSTYLES, yscrollcommand = scrollbar.set)
    scrollable_frame = tk.Frame(canvas, **KWSTYLES)
    # Keep the scrollable region in sync with the frame's actual content size
    scrollable_frame.bind("<Configure>", lambda e: canvas.configure(scrollregion = canvas.bbox("all")))
    canvas.create_window((0, 0), window = scrollable_frame, anchor = "nw")
    canvas.pack(side = "left", fill = "both", expand = True)
    scrollbar.pack(side = "right", fill = "y")
    canvas.bind_all("<MouseWheel>", on_mousewheel) # Windows/macOS
    canvas.bind_all("<Button-4>", on_mousewheel) # Linux up
    canvas.bind_all("<Button-5>", on_mousewheel) # Linux down
    canvas.bind("<Configure>", on_resize)
    # Preload thumbnails
    thumbnails = [(make_thumbnail(path), path) for path in image_paths]
    # Initial grid
    root.update_idletasks()
    initial_columns = max(1, (root.winfo_width() - THUMBNAIL_PADDING) // THUMB_TOTAL_WIDTH)
    draw_grid(initial_columns)
    # Go
    root.mainloop()