Qwen2-VL: Basic video support

This commit is contained in:
turboderp
2024-12-15 23:32:41 +01:00
parent c78d9027aa
commit 4061c24373
28 changed files with 295 additions and 18 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 73 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

View File

@@ -0,0 +1,148 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Tokenizer,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2Sampler,
)
from PIL import Image
import requests, glob
import torch
# Compact tensor printing for debugging intermediate tensors
torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
# Model used:
#
# Qwen2-VL:
# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
# Stream tokens as they are generated (else generate the full reply in one call)
streaming = True
# Greedy sampling gives a deterministic description
greedy = True
model_directory = "/mnt/str/models/qwen2-vl-7b-instruct-exl2/6.0bpw"
# Video frames as individual image files, sorted so they are in chronological order
frames = [
{"file": f}
for f in sorted(glob.glob("media/test_video_*.png"))
]
instruction = "Describe this video."
# Initialize model
config = ExLlamaV2Config(model_directory)
config.max_seq_len = 16384  # Cap the context length; the model's configured default can be much larger
# Load vision model and multimodal projector and initialize preprocessor
vision_model = ExLlamaV2VisionTower(config)
vision_model.load(progress = True)
# Load EXL2 model
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)
# Create generator
generator = ExLlamaV2DynamicGenerator(
model = model,
cache = cache,
tokenizer = tokenizer,
)
# Util function to get a PIL image from a URL or from a file in the script's directory
def get_image(file = None, url = None):
    """
    Load a PIL image from a local file (resolved relative to this script's directory) or from a URL.
    Exactly one of the two sources must be given.

    :param file:
        Filename, relative to the directory containing this script

    :param url:
        URL to fetch the image from

    :raises ValueError:
        If neither or both of file/url are supplied
    """
    # Raise instead of assert so the validation survives `python -O`
    if bool(file) == bool(url):
        raise ValueError("Specify exactly one of 'file' or 'url'")
    if file:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(script_dir, file)
        return Image.open(file_path)
    # Stream the HTTP response body straight into PIL
    return Image.open(requests.get(url, stream = True).raw)
# Convert the video frames to a single embedding. An alias can be given explicitly with the text_alias
# argument; here we rely on the automatically assigned unique identifier instead
video_embedding = vision_model.get_video_embeddings(
model = model,
tokenizer = tokenizer,
video = [get_image(**img_args) for img_args in frames],
)
# The generator/tokenizer APIs expect a list of embeddings, one entry per multimodal input
video_embeddings = [video_embedding]
# Define prompt in ChatML format; the alias string is replaced by the video embedding at encode time
prompt = (
"<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
video_embedding.text_alias +
# "\n" +
instruction +
"<|im_end|>\n" +
"<|im_start|>assistant\n"
)
# Generate
if streaming:

    # Tokenize the prompt manually so the video embedding can be spliced in at its alias
    input_ids = tokenizer.encode(
        prompt,
        encode_special_tokens = True,
        embeddings = video_embeddings,
    )

    job = ExLlamaV2DynamicJob(
        input_ids = input_ids,
        max_new_tokens = 500,
        decode_special_tokens = True,
        stop_conditions = [tokenizer.eos_token_id],
        gen_settings = ExLlamaV2Sampler.Settings.greedy() if greedy else None,
        embeddings = video_embeddings,
    )

    generator.enqueue(job)

    # Echo the prompt, then stream the completion token by token as jobs produce output
    print()
    print(prompt, end = ""); sys.stdout.flush()

    while generator.num_remaining_jobs():
        results = generator.iterate()
        for result in results:
            text = result.get("text", "")
            print(text, end = ""); sys.stdout.flush()

    print()

else:

    # One-shot generation; the generator tokenizes internally, so pass the embeddings along with the prompt
    output = generator.generate(
        prompt = prompt,
        max_new_tokens = 500,
        add_bos = True,
        encode_special_tokens = True,
        decode_special_tokens = True,
        stop_conditions = [tokenizer.eos_token_id],
        gen_settings = ExLlamaV2Sampler.Settings.greedy() if greedy else None,
        embeddings = video_embeddings,
    )
    print(output)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import torch
import numpy as np
from PIL import Image
@@ -80,11 +82,15 @@ def position_embeddings(
max_width: int,
rope_sin: torch.Tensor,
rope_cos: torch.Tensor,
thw_grid: tuple | None = None,
):
"""
Create flat position IDs tensor for grid of patches: id(row, col) = row * max_width + col
"""
assert thw_grid is None, \
"Video not supported for Pixtral"
row_indices = torch.arange(height).unsqueeze(1) * max_width
col_indices = torch.arange(width).unsqueeze(0)
ids = row_indices + col_indices
@@ -93,4 +99,3 @@ def position_embeddings(
cos = rope_cos[ids]
sin = rope_sin[ids]
return sin, cos

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import torch
import numpy as np
from PIL import Image
@@ -11,7 +13,7 @@ from exllamav2.vlm.util import (
def preprocess(
config: ExLlamaV2Config,
image: Image
images: Image | list[Image]
) -> (torch.Tensor, tuple):
resample = Image.Resampling(config.vision_resample)
@@ -19,30 +21,48 @@ def preprocess(
image_std = tuple(config.vision_image_std)
rescale_factor = config.vision_rescale_factor
# Make list and truncate to whole number of spatial patches
if not isinstance(images, list):
mode = "image"
images = [images]
else:
mode = "video"
g = config.vision_temporal_patch_size
frames = len(images)
if frames > 1:
frames = frames // g * g
images = images[:frames]
# Convert to RGB and resize as necessary
image = convert_to_rgb(image)
old_size = image.size
images = [convert_to_rgb(image) for image in images]
old_size = images[0].size
assert all(old_size == frame.size for frame in images), \
"All frames in video must have same dimensions"
new_size = smart_resize(
image.size,
old_size,
config.vision_spatial_patch_size * config.vision_spatial_merge_size,
config.vision_min_pixels,
config.vision_max_pixels,
)
if old_size != new_size:
image = image.resize(new_size, resample = resample)
images = [image.resize(new_size, resample = resample) for image in images]
# Convert to numpy array and normalize
image = np.array(image).astype(np.float32)
image = image * rescale_factor
image = normalize_image(image, image_mean, image_std)
images = [np.array(image).astype(np.float32) for image in images]
images = [image * rescale_factor for image in images]
images = [normalize_image(image, image_mean, image_std) for image in images]
# Reshape and convert to tensor
image = image.transpose(2, 0, 1)
patches = np.array([image])
patches = np.tile(patches, (config.vision_temporal_patch_size, 1, 1, 1))
patches = np.array(images)
patches = patches.transpose(0, 3, 1, 2)
if patches.shape[0] == 1:
patches = np.tile(patches, (config.vision_temporal_patch_size, 1, 1, 1))
channels = patches.shape[1]
grid_t = patches.shape[0] // config.vision_temporal_patch_size
grid_h = new_size[1] // config.vision_spatial_patch_size
@@ -64,8 +84,12 @@ def preprocess(
channels * config.vision_temporal_patch_size * config.vision_spatial_patch_size ** 2
)
image = torch.from_numpy(flatten_patches).half()
return image, new_size
if mode == "image":
image = torch.from_numpy(flatten_patches).half()
return image, new_size
else:
video = torch.from_numpy(flatten_patches).half()
return video, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
def postprocess(
model: ExLlamaV2,
@@ -94,13 +118,17 @@ def position_embeddings(
max_width: int,
rope_sin: torch.Tensor,
rope_cos: torch.Tensor,
thw_grid: tuple | None = None,
):
"""
Create position IDs for Qwen2 grid
"""
t = 1 # TODO: t dimension
h, w = height, width
if thw_grid is not None:
t, h, w = thw_grid
else:
h, w = height, width
spm = config.vision_spatial_merge_size
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -112,7 +140,9 @@ def position_embeddings(
wpos_ids = wpos_ids.reshape(h // spm, spm, w // spm, spm)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
ids = torch.stack([hpos_ids, wpos_ids], dim = -1).repeat(t, 1)
# ids = torch.stack([hpos_ids, wpos_ids], dim = -1).repeat(t, 1)
ids = torch.stack([hpos_ids, wpos_ids], dim = -1)
cos = rope_cos[ids].flatten(1)
sin = rope_sin[ids].flatten(1)

View File

@@ -15,6 +15,7 @@ from exllamav2.compat import safe_move_tensor
from exllamav2.generator import ExLlamaV2MMEmbedding
from exllamav2.vlm.processor import pixtral, qwen2
from exllamav2.vlm.util import convert_to_rgb
from PIL.Image import Image
@@ -41,9 +42,13 @@ class ExLlamaV2VisionTower(ExLlamaV2):
if cfg.vision_model_type == "pixtral":
self.preprocess_func = pixtral.preprocess
self.postprocess_func = pixtral.postprocess
self.video_preprocess_func = None
self.video_postprocess_func = None
elif cfg.vision_model_type == "qwen2":
self.preprocess_func = qwen2.preprocess
self.postprocess_func = qwen2.postprocess
self.video_preprocess_func = qwen2.preprocess
self.video_postprocess_func = qwen2.postprocess
else:
raise ValueError(f"Unknown vision model type: {cfg.vision_model_type}")
@@ -165,6 +170,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
hidden_states: torch.Tensor,
patches_size = None,
abort_event: threading.Event | None = None,
thw_grid: tuple | None = None,
**kwargs
):
cfg = self.config
@@ -188,7 +194,8 @@ class ExLlamaV2VisionTower(ExLlamaV2):
p_width,
self.p_maxedge,
self.rope_sin,
self.rope_cos
self.rope_cos,
thw_grid
)
attn_params = ExLlamaV2Attention.Params(non_causal_attn = True)
@@ -211,6 +218,14 @@ class ExLlamaV2VisionTower(ExLlamaV2):
cos = safe_move_tensor(cos, hidden_states.device)
sin = safe_move_tensor(sin, hidden_states.device)
if thw_grid is not None and isinstance(module, ExLlamaV2Attention):
pa_shape = hidden_states.shape
hidden_states = hidden_states.view(
thw_grid[0],
hidden_states.shape[1] // thw_grid[0],
hidden_states.shape[2]
)
hidden_states = module.forward(
hidden_states,
attn_params = attn_params,
@@ -219,6 +234,9 @@ class ExLlamaV2VisionTower(ExLlamaV2):
}
)
if thw_grid is not None and isinstance(module, ExLlamaV2Attention):
hidden_states = hidden_states.view(pa_shape)
return hidden_states
@@ -295,4 +313,80 @@ class ExLlamaV2VisionTower(ExLlamaV2):
"patches_size": (features_y, features_x),
})
return mme
def get_video_embeddings(
    self,
    model: ExLlamaV2,
    tokenizer: ExLlamaV2Tokenizer,
    video: list[Image],
    text_alias: str | None = None,
    embeddings_cpu: bool = True
) -> ExLlamaV2MMEmbedding:
    """
    Create a multimodal embedding from a video, supplied as a list of equal-sized PIL image frames.

    :param model:
        Text model for which to produce embeddings

    :param tokenizer:
        Tokenizer

    :param video:
        Video as list of PIL images, one per frame

    :param text_alias:
        Text string to represent this embedding for tokenizing

    :param embeddings_cpu:
        Move embeddings to CPU. This can be skipped for simple jobs, but ideally embeddings should be cached
        when used with the dynamic generator, and it is not ideal to keep some large cache of data in VRAM. The
        overhead of copying them back to VRAM is relatively low. If this argument is False, embeddings will
        reside on whatever device the vision tower is loaded on.

    :raises ValueError:
        If the loaded vision model has no video support, or frames differ in size

    :return:
        ExLlamaV2MMEmbedding
    """

    # Fail fast for architectures without video support; the video hooks are set to None for
    # those model types (e.g. Pixtral) when the tower is constructed
    if self.video_preprocess_func is None or self.video_postprocess_func is None:
        raise ValueError(f"Video input not supported for vision model type: {self.config.vision_model_type}")

    # All frames must share one resolution. Raise (not assert) so the check survives python -O
    width, height = video[0].size
    if any((width, height) != frame.size for frame in video):
        raise ValueError("All video frames must have same dimensions")
    original_size = (height, width)

    # Preprocess the full frame list into a flat patch tensor plus the (t, h, w) patch grid
    video_tensor, prep_image_size, video_grid_thw, _merge = self.video_preprocess_func(self.config, video)

    features_x = prep_image_size[0] // self.config.vision_patch_size["width"]
    features_y = prep_image_size[1] // self.config.vision_patch_size["height"]

    # Run the vision tower; passing thw_grid switches position embeddings and attention into video mode
    embedding_tensor = self.process(
        video_tensor,
        (features_y, features_x),
        thw_grid = video_grid_thw,
    )

    if embeddings_cpu:
        embedding_tensor = embedding_tensor.cpu()

    # Project into the text model's embedding space and collect any wrapping tokens
    embedding_tensor, pre_tokens, post_tokens = self.video_postprocess_func(
        model,
        tokenizer,
        embedding_tensor[0],
        features_y,
        features_x,
    )

    mme = ExLlamaV2MMEmbedding(
        model = model,
        embeddings = embedding_tensor,
        text_alias = text_alias,
        thw_grid = video_grid_thw,
        pre_tokens = pre_tokens,
        post_tokens = post_tokens
    )

    mme.metadata.update({
        "original_size": original_size,
        "preprocessed_size": prep_image_size,
        "patches_size": (features_y, features_x),
    })

    return mme