Initial support for Qwen2.5-VL

This commit is contained in:
turboderp
2025-01-29 02:55:44 +01:00
parent d0413b06f8
commit cce6f95cd3
8 changed files with 222 additions and 48 deletions

View File

@@ -356,7 +356,7 @@ class ExLlamaV2ArchParams:
# Qwen2-VL (2, 2.5)
if arch_string == "Qwen2VLForConditionalGeneration":
if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
arch_recognized = True
self.lm.layer_keys += \
layer_keys_llama_norms + \
@@ -368,27 +368,44 @@ class ExLlamaV2ArchParams:
self.lm.mrope = True
self.lm.rope_freq_half = True
read_config["vision_config"].update({"model_type": "qwen2"})
self.vt_prefix = "visual."
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": None,
"mlp_up": ".mlp.fc1",
"mlp_down": ".mlp.fc2",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = False
if arch_string == "Qwen2VLForConditionalGeneration":
read_config["vision_config"].update({"model_type": "qwen2"})
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": None,
"mlp_up": ".mlp.fc1",
"mlp_down": ".mlp.fc2",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = False
self.vt.mlp_act_func = "quickgelu"
self.vt.norm = "layernorm"
elif arch_string == "Qwen2_5_VLForConditionalGeneration":
read_config["vision_config"].update({"model_type": "qwen2.5"})
self.vt.keys.update({
"fused_qkv": ".attn.qkv",
"attn_o": ".attn.proj",
"mlp_gate": ".mlp.gate_proj",
"mlp_up": ".mlp.up_proj",
"mlp_down": ".mlp.down_proj",
"norm_1": ".norm1",
"norm_2": ".norm2",
"layers": "blocks",
"patch_conv": "patch_embed.proj",
})
self.vt.mlp_gate = True
self.vt.mlp_act_func = "silu"
self.vt.norm = "rmsnorm"
self.vt.mlp_bias = True
self.vt.attention_bias_qkv = True
self.vt.attention_bias_o = True
self.vt.vision_input_norm = False
self.vt.vision_conv3d = True
self.vt.mlp_act_func = "quickgelu"
self.vt.norm = "layernorm"
self.mmp_prefix = "visual.merger."
self.mmp.keys.update({

View File

@@ -41,11 +41,11 @@ if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ:
print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")
if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func, flash_attn_varlen_func
has_flash_attn = True
if [2, 5, 7] <= flash_attn_ver:
from flash_attn import flash_attn_func, flash_attn_with_kvcache
from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
# import flash_attn_2_cuda as flash_attn_cuda
signature = list(inspect.signature(flash_attn_func).parameters)
@@ -882,7 +882,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
k_states = k_states[:, :, -self.sliding_window:, :]
v_states = v_states[:, :, -self.sliding_window:, :]
if attn_params.is_causal():
if self.layer_idx in attn_params.block_diag_layers:
attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
elif attn_params.is_causal():
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
else:
attn_mask_lr = attn_params.get_attn_mask(q_states.device)
@@ -904,7 +906,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
attn_weights = torch.matmul(q_states, k_states)
attn_weights *= self.scaling
if causal:
if self.layer_idx in attn_params.block_diag_layers:
attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
elif causal:
attn_mask = attn_params.get_attn_mask(attn_weights.device)
if cfg.attn_logit_softcapping:
@@ -939,14 +943,30 @@ class ExLlamaV2Attention(ExLlamaV2Module):
if has_flash_attn_with_softcap:
flash_kwargs["softcap"] = cfg.attn_logit_softcapping
attn_output = flash_attn_func(
q_states,
k_states,
v_states,
causal = causal,
softmax_scale = self.scaling,
**flash_kwargs
)
if self.layer_idx in attn_params.block_diag_layers:
q_states = q_states.flatten(start_dim = 0, end_dim = 1)
k_states = k_states.flatten(start_dim = 0, end_dim = 1)
v_states = v_states.flatten(start_dim = 0, end_dim = 1)
max_seqlen = attn_params.get_cu_seqlens_max()
cu_seqlens = attn_params.get_cu_seqlens(self.device_idx)
attn_output = flash_attn_varlen_func(
q_states,
k_states,
v_states,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen
)
else:
attn_output = flash_attn_func(
q_states,
k_states,
v_states,
causal = causal,
softmax_scale = self.scaling,
**flash_kwargs
)
attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
return attn_output

View File

@@ -21,6 +21,10 @@ class Params:
alt_rope_embed_dict: dict | None
rope_offsets: torch.Tensor | None
non_causal_attn: bool
block_diag_layers: set
block_diag_mask: torch.Tensor | None
cu_seqlens: torch.Tensor | None
cu_seqlens_max: int | None
def __init__(
self,
@@ -66,6 +70,11 @@ class Params:
self.past_len_tp = None
self.paged = paged
self.block_diag_layers = set()
self.block_diag_mask = None
self.cu_seqlens = None
self.cu_seqlens_max = None
def is_causal(self) -> bool:
    """Return True when no explicit attention mask is set, i.e. plain causal attention applies."""
    has_mask = self.input_mask is not None
    return not has_mask
@@ -164,6 +173,31 @@ class Params:
self.rope_offsets = safe_move_tensor(self.rope_offsets, device_idx, non_blocking = True)
return self.rope_offsets
def get_cu_seqlens(self, device: int) -> torch.Tensor | None:
if self.cu_seqlens is None:
return None
if self.cu_seqlens.device.index != device:
self.cu_seqlens = safe_move_tensor(self.cu_seqlens, device, non_blocking = True)
return self.cu_seqlens
def get_cu_seqlens_max(self) -> int:
    """
    Return the length of the longest individual sequence described by `cu_seqlens`.

    Computed as the max difference between consecutive cumulative offsets and
    cached in `self.cu_seqlens_max`; `cu_seqlens` must already be set.
    Note: the original annotation claimed `torch.Tensor | None`, but `.item()`
    always yields a Python int and the assert rules out None.
    """
    assert self.cu_seqlens is not None
    if self.cu_seqlens_max is not None:
        return self.cu_seqlens_max
    # Adjacent differences of the cumulative offsets are the per-sequence lengths
    self.cu_seqlens_max = (self.cu_seqlens[1:] - self.cu_seqlens[:-1]).max().item()
    return self.cu_seqlens_max
def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
if self.block_diag_mask is None:
csl = self.get_cu_seqlens(device)
if csl is None:
return None
positions = torch.arange(csl[-1], device = csl.device)
labels = torch.searchsorted(csl[1:], positions, right = True)
self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
if self.block_diag_mask.device.index != device:
self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
return self.block_diag_mask
class PagedParams(Params):

View File

@@ -135,6 +135,7 @@ class ExLlamaV2Config:
vision_num_key_value_groups: int | None
vision_hidden_size: int | None
vision_intermediate_size: int | None
vision_merger_intermediate_size: int | None
vision_hidden_act: str | None
vision_rope_theta: float | None
vision_feature_layer: int | None
@@ -152,6 +153,8 @@ class ExLlamaV2Config:
vision_max_pixels: int | None
vision_temporal_patch_size: int | None
vision_max_size: int | None
vision_fullatt_block_indexes: list | None
vision_window_size: int | None
# Deprecated fields, kept for compatibility
@@ -478,6 +481,8 @@ class ExLlamaV2Config:
# TODO: Cleanup & refactor
self.vision_fullatt_block_indexes = None
if self.vision_model_type is None:
pass
@@ -495,6 +500,7 @@ class ExLlamaV2Config:
self.vision_feature_layer = read(read_config, int, ["vision_feature_layer"], no_default)
self.vision_num_layers = read(read_config, int, ["vision_config->num_hidden_layers"], 24)
self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], self.hidden_size)
self.vision_merger_intermediate_size = self.vision_intermediate_size
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
assert image_processor_type == "PixtralImageProcessor", \
@@ -511,10 +517,27 @@ class ExLlamaV2Config:
self.vision_spatial_merge_size = 1
self.vision_max_size = 16384
elif self.vision_model_type == "qwen2":
elif self.vision_model_type in ["qwen2", "qwen2.5"]:
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
if self.vision_model_type == "qwen2":
self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], None)
self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
self.vision_merger_intermediate_size = self.vision_intermediate_size
assert image_processor_type == "Qwen2VLImageProcessor", \
f"Wrong image processor type: {image_processor_type}"
self.vision_window_size = None
elif self.vision_model_type == "qwen2.5":
self.vision_hidden_size = read(read_config, int, ["vision_config->hidden_size"], no_default)
self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], no_default)
self.vision_fullatt_block_indexes = read(read_config, list, ["vision_config->fullatt_block_indexes", None])
self.vision_window_size = read(read_config, int, ["vision_config->window_size", None])
assert image_processor_type == "Qwen2_5_VLImageProcessor", \
f"Wrong image processor type: {image_processor_type}"
self.vision_merger_intermediate_size = 5120 # TODO: This doesn't seem to appear in the config anywhere?
self.vision_num_attention_heads = read(read_config, int, ["vision_config->num_heads"], no_default)
self.vision_num_key_value_heads = self.vision_num_attention_heads
self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
self.vision_head_dim = self.vision_hidden_size // self.vision_num_attention_heads
self.vision_num_key_value_groups = 1
self.vision_hidden_act = "quickgelu"
@@ -523,12 +546,7 @@ class ExLlamaV2Config:
patch_size = read(read_config, int, ["vision_config->patch_size"], no_default)
self.vision_rope_theta = read(read_config, int, ["vision_config->rope_theta"], 10000.0)
self.vision_num_layers = read(read_config, int, ["vision_config->depth"], no_default)
mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], no_default)
self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
assert image_processor_type == "Qwen2VLImageProcessor", \
f"Wrong image processor type: {image_processor_type}"
self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
assert read(read_prep_config, int, ["patch_size"], no_default) == patch_size, \

View File

@@ -51,6 +51,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
out_features: int | None = None,
interm_features: int | None = None,
merge: int | None = None,
pad32: bool = True,
):
super().__init__(model, key, archparams)
cfg = self.model.config
@@ -98,8 +99,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.pre_layernorm = None
self.post_layernorm = None
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth)
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c, pad32 = pad32)
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth, pad32 = pad32)
self.submodules = [self.up_proj,
self.down_proj]
@@ -109,7 +110,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.submodules += [self.post_layernorm]
if ap.mlp_gate:
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b, pad32 = pad32)
self.submodules += [self.gate_proj]
else:
self.gate_proj = None

View File

@@ -44,7 +44,7 @@ def preprocess(
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).half()
return image, new_size
return image, new_size, None, None
def postprocess(
model: ExLlamaV2,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
@@ -86,7 +87,7 @@ def preprocess(
if mode == "image":
image = torch.from_numpy(flatten_patches).half()
return image, new_size
return image, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
else:
video = torch.from_numpy(flatten_patches).half()
return video, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
@@ -149,4 +150,51 @@ def position_embeddings(
cos = cos.unsqueeze(1).repeat(1, 1, 2).contiguous()
sin = sin.unsqueeze(1).repeat(1, 1, 2).contiguous()
return sin, cos
return sin, cos
def get_window_index(grid_thw, config: ExLlamaV2Config):
    """
    Compute the token reordering used by Qwen2.5-VL windowed attention.

    For each (t, h, w) patch grid in `grid_thw`, merged tokens are regrouped into
    square windows of `vit_merger_window_size` per side; ragged edges are padded
    with the sentinel -100 and the padding filtered back out.

    Returns (window_index, cu_window_seqlens):
      - window_index: 1-D LongTensor of merged-token indices in window order,
        concatenated over all grids.
      - cu_window_seqlens: Python list of cumulative per-window lengths, counted
        in pre-merge patches (per-window merged-token counts scaled by
        spatial_merge_size**2), starting at 0.

    NOTE(review): `.item()` below implies the grid entries are 0-dim tensors
    (rows of a tensor), not plain ints — confirm against callers.
    """
    window_index: list = []
    cu_window_seqlens: list = [0]
    window_index_id = 0
    # Window side length measured in merged tokens (pixels -> patches -> merged tokens)
    vit_merger_window_size = (
        config.vision_window_size //
        config.vision_spatial_merge_size //
        config.vision_patch_size["height"]
    )
    for grid_t, grid_h, grid_w in grid_thw:
        # Grid dimensions after spatial merging
        llm_grid_h, llm_grid_w = (
            grid_h // config.vision_spatial_merge_size,
            grid_w // config.vision_spatial_merge_size,
        )
        # Sequential index of every merged token in this grid
        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
        # Pad up to a multiple of the window size (pads a full extra window when
        # already divisible; the -100 entries are dropped below either way)
        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
        # Split the padded grid into (window_row, in_window_row, window_col, in_window_col)
        index_padded = index_padded.reshape(
            grid_t,
            num_windows_h,
            vit_merger_window_size,
            num_windows_w,
            vit_merger_window_size,
        )
        # Group the two window axes together so each window is contiguous
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
            grid_t,
            num_windows_h * num_windows_w,
            vit_merger_window_size,
            vit_merger_window_size,
        )
        # Count valid (non-padding) tokens per window
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
        index_padded = index_padded.reshape(-1)
        # Drop padding; remaining indices are in window-major order
        index_new = index_padded[index_padded != -100]
        # Offset into the global token sequence across multiple grids
        window_index.append(index_new + window_index_id)
        # Cumulative window lengths in pre-merge patch units, continuing from the previous grid
        cu_seqlens_tmp = seqlens.cumsum(0) * config.vision_spatial_merge_size**2 + cu_window_seqlens[-1]
        cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
        window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
    window_index = torch.cat(window_index, dim =0)
    return window_index, cu_window_seqlens

View File

@@ -4,6 +4,7 @@ import os, sys
import threading
import torch
import torch.nn.functional as F
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.conv import ExLlamaV2Conv
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
@@ -44,7 +45,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
self.postprocess_func = pixtral.postprocess
self.video_preprocess_func = None
self.video_postprocess_func = None
elif cfg.vision_model_type == "qwen2":
elif cfg.vision_model_type in ["qwen2", "qwen2.5"]:
self.preprocess_func = qwen2.preprocess
self.postprocess_func = qwen2.postprocess
self.video_preprocess_func = qwen2.preprocess
@@ -76,7 +77,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
self.position_emb_func = pixtral.position_embeddings
elif cfg.vision_model_type == "qwen2":
elif cfg.vision_model_type in ["qwen2", "qwen2.5"]:
self.p_maxedge = cfg.vision_max_size
dim = cfg.vision_head_dim // 2
max_seqlen = int(math.ceil(cfg.vision_max_size / cfg.vision_spatial_patch_size))
@@ -130,7 +131,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
for layer_idx in range(self.config.vision_num_layers):
layer_key = cfg.arch.vt_prefix + km["layers"] + f".{layer_idx}"
attn = ExLlamaV2Attention(self, layer_key, layer_idx, archparams = self.archparams)
mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams)
mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams, pad32 = False)
self.modules += [attn, mlp]
# Multimodal projection
@@ -143,10 +144,11 @@ class ExLlamaV2VisionTower(ExLlamaV2):
archparams = cfg.arch.mmp,
in_features = cfg.vision_hidden_size * merge,
out_features = cfg.hidden_size,
interm_features = cfg.vision_intermediate_size,
interm_features = cfg.vision_merger_intermediate_size,
has_norm = True,
has_residual = False,
merge = merge,
pad32 = False,
)
self.modules += [mmp]
@@ -199,6 +201,19 @@ class ExLlamaV2VisionTower(ExLlamaV2):
)
attn_params = ExLlamaV2Attention.Params(non_causal_attn = True)
if self.config.vision_window_size:
attn_params.block_diag_layers = set([
l for l in range(self.config.vision_num_layers)
if l not in self.config.vision_fullatt_block_indexes
])
thw_grid_t = torch.tensor(list(thw_grid), dtype = torch.int).unsqueeze(0)
window_index, csl = qwen2.get_window_index(thw_grid_t, self.config)
csl = torch.tensor(csl, device = hidden_states.device, dtype = torch.int)
csl = torch.unique_consecutive(csl)
attn_params.cu_seqlens = csl
else:
window_index = None
device = self.modules[0].device_idx
for idx, module in enumerate(self.modules):
@@ -230,13 +245,33 @@ class ExLlamaV2VisionTower(ExLlamaV2):
hidden_states,
attn_params = attn_params,
**kwargs | {
"alt_rope_embedding": (cos, sin)
"alt_rope_embedding": (cos, sin),
}
)
if thw_grid is not None and isinstance(module, ExLlamaV2Attention):
hidden_states = hidden_states.view(pa_shape)
if window_index is not None and idx == 0:
sh = hidden_states.shape
unit = self.config.vision_spatial_merge_size ** 2
seq_len = hidden_states.shape[0] * hidden_states.shape[1]
hidden_states = hidden_states.reshape(seq_len // unit, unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(sh)
sh = sin.shape
seq_len = sin.shape[0]
sin = sin.reshape(seq_len // unit, unit, -1)
sin = sin[window_index[:seq_len // unit], :, :]
sin = sin.reshape(sh)
cos = cos.reshape(seq_len // unit, unit, -1)
cos = cos[window_index[:seq_len // unit], :, :]
cos = cos.reshape(sh)
if window_index is not None:
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[:, reverse_indices]
return hidden_states
@@ -278,13 +313,14 @@ class ExLlamaV2VisionTower(ExLlamaV2):
assert all(s <= maxsize for s in original_size), \
f"Input image exceeds maximum size of {maxsize} x {maxsize}"
image_tensor, prep_image_size = self.preprocess_func(self.config, image)
image_tensor, prep_image_size, grid_thw, _ = self.preprocess_func(self.config, image)
features_x = prep_image_size[0] // self.config.vision_patch_size["width"]
features_y = prep_image_size[1] // self.config.vision_patch_size["height"]
embedding_tensor = self.process(
image_tensor,
(features_y, features_x)
(features_y, features_x),
thw_grid = grid_thw
)
if embeddings_cpu: