mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Initial support for Qwen2.5-VL
This commit is contained in:
@@ -356,7 +356,7 @@ class ExLlamaV2ArchParams:
|
||||
|
||||
# Qwen2-VL (2, 2.5)
|
||||
|
||||
if arch_string == "Qwen2VLForConditionalGeneration":
|
||||
if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
|
||||
arch_recognized = True
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
@@ -368,27 +368,44 @@ class ExLlamaV2ArchParams:
|
||||
self.lm.mrope = True
|
||||
self.lm.rope_freq_half = True
|
||||
|
||||
read_config["vision_config"].update({"model_type": "qwen2"})
|
||||
self.vt_prefix = "visual."
|
||||
self.vt.keys.update({
|
||||
"fused_qkv": ".attn.qkv",
|
||||
"attn_o": ".attn.proj",
|
||||
"mlp_gate": None,
|
||||
"mlp_up": ".mlp.fc1",
|
||||
"mlp_down": ".mlp.fc2",
|
||||
"norm_1": ".norm1",
|
||||
"norm_2": ".norm2",
|
||||
"layers": "blocks",
|
||||
"patch_conv": "patch_embed.proj",
|
||||
})
|
||||
self.vt.mlp_gate = False
|
||||
if arch_string == "Qwen2VLForConditionalGeneration":
|
||||
read_config["vision_config"].update({"model_type": "qwen2"})
|
||||
self.vt.keys.update({
|
||||
"fused_qkv": ".attn.qkv",
|
||||
"attn_o": ".attn.proj",
|
||||
"mlp_gate": None,
|
||||
"mlp_up": ".mlp.fc1",
|
||||
"mlp_down": ".mlp.fc2",
|
||||
"norm_1": ".norm1",
|
||||
"norm_2": ".norm2",
|
||||
"layers": "blocks",
|
||||
"patch_conv": "patch_embed.proj",
|
||||
})
|
||||
self.vt.mlp_gate = False
|
||||
self.vt.mlp_act_func = "quickgelu"
|
||||
self.vt.norm = "layernorm"
|
||||
elif arch_string == "Qwen2_5_VLForConditionalGeneration":
|
||||
read_config["vision_config"].update({"model_type": "qwen2.5"})
|
||||
self.vt.keys.update({
|
||||
"fused_qkv": ".attn.qkv",
|
||||
"attn_o": ".attn.proj",
|
||||
"mlp_gate": ".mlp.gate_proj",
|
||||
"mlp_up": ".mlp.up_proj",
|
||||
"mlp_down": ".mlp.down_proj",
|
||||
"norm_1": ".norm1",
|
||||
"norm_2": ".norm2",
|
||||
"layers": "blocks",
|
||||
"patch_conv": "patch_embed.proj",
|
||||
})
|
||||
self.vt.mlp_gate = True
|
||||
self.vt.mlp_act_func = "silu"
|
||||
self.vt.norm = "rmsnorm"
|
||||
self.vt.mlp_bias = True
|
||||
self.vt.attention_bias_qkv = True
|
||||
self.vt.attention_bias_o = True
|
||||
self.vt.vision_input_norm = False
|
||||
self.vt.vision_conv3d = True
|
||||
self.vt.mlp_act_func = "quickgelu"
|
||||
self.vt.norm = "layernorm"
|
||||
|
||||
self.mmp_prefix = "visual.merger."
|
||||
self.mmp.keys.update({
|
||||
|
||||
@@ -41,11 +41,11 @@ if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ:
|
||||
print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")
|
||||
|
||||
if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
|
||||
from flash_attn import flash_attn_func
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
has_flash_attn = True
|
||||
|
||||
if [2, 5, 7] <= flash_attn_ver:
|
||||
from flash_attn import flash_attn_func, flash_attn_with_kvcache
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
# import flash_attn_2_cuda as flash_attn_cuda
|
||||
|
||||
signature = list(inspect.signature(flash_attn_func).parameters)
|
||||
@@ -882,7 +882,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
k_states = k_states[:, :, -self.sliding_window:, :]
|
||||
v_states = v_states[:, :, -self.sliding_window:, :]
|
||||
|
||||
if attn_params.is_causal():
|
||||
if self.layer_idx in attn_params.block_diag_layers:
|
||||
attn_mask_lr = attn_params.get_block_diag_mask(q_states.device)
|
||||
elif attn_params.is_causal():
|
||||
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
|
||||
else:
|
||||
attn_mask_lr = attn_params.get_attn_mask(q_states.device)
|
||||
@@ -904,7 +906,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
attn_weights = torch.matmul(q_states, k_states)
|
||||
|
||||
attn_weights *= self.scaling
|
||||
if causal:
|
||||
if self.layer_idx in attn_params.block_diag_layers:
|
||||
attn_mask = attn_params.get_block_diag_mask(attn_weights.device)
|
||||
elif causal:
|
||||
attn_mask = attn_params.get_attn_mask(attn_weights.device)
|
||||
|
||||
if cfg.attn_logit_softcapping:
|
||||
@@ -939,14 +943,30 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
if has_flash_attn_with_softcap:
|
||||
flash_kwargs["softcap"] = cfg.attn_logit_softcapping
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
q_states,
|
||||
k_states,
|
||||
v_states,
|
||||
causal = causal,
|
||||
softmax_scale = self.scaling,
|
||||
**flash_kwargs
|
||||
)
|
||||
if self.layer_idx in attn_params.block_diag_layers:
|
||||
q_states = q_states.flatten(start_dim = 0, end_dim = 1)
|
||||
k_states = k_states.flatten(start_dim = 0, end_dim = 1)
|
||||
v_states = v_states.flatten(start_dim = 0, end_dim = 1)
|
||||
max_seqlen = attn_params.get_cu_seqlens_max()
|
||||
cu_seqlens = attn_params.get_cu_seqlens(self.device_idx)
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q_states,
|
||||
k_states,
|
||||
v_states,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
q_states,
|
||||
k_states,
|
||||
v_states,
|
||||
causal = causal,
|
||||
softmax_scale = self.scaling,
|
||||
**flash_kwargs
|
||||
)
|
||||
attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
|
||||
return attn_output
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ class Params:
|
||||
alt_rope_embed_dict: dict | None
|
||||
rope_offsets: torch.Tensor | None
|
||||
non_causal_attn: bool
|
||||
block_diag_layers: set
|
||||
block_diag_mask: torch.Tensor | None
|
||||
cu_seqlens: torch.Tensor | None
|
||||
cu_seqlens_max: int | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -66,6 +70,11 @@ class Params:
|
||||
self.past_len_tp = None
|
||||
self.paged = paged
|
||||
|
||||
self.block_diag_layers = set()
|
||||
self.block_diag_mask = None
|
||||
self.cu_seqlens = None
|
||||
self.cu_seqlens_max = None
|
||||
|
||||
def is_causal(self) -> bool:
|
||||
return self.input_mask is None
|
||||
|
||||
@@ -164,6 +173,31 @@ class Params:
|
||||
self.rope_offsets = safe_move_tensor(self.rope_offsets, device_idx, non_blocking = True)
|
||||
return self.rope_offsets
|
||||
|
||||
def get_cu_seqlens(self, device: int) -> torch.Tensor | None:
    """
    Return the stored `cu_seqlens` tensor, relocated to `device` if necessary.

    :param device:
        Device index that the returned tensor's `device.index` should match.

    :return:
        None when no `cu_seqlens` tensor has been set, otherwise the tensor. If a relocation was
        needed, the relocated tensor replaces `self.cu_seqlens` so later calls for the same device
        skip the transfer.
    """

    csl = self.cu_seqlens

    # Only touch the tensor when it exists and reports a different device index
    if csl is not None and csl.device.index != device:
        # presumably safe_move_tensor performs the (non-blocking) transfer -- defined elsewhere
        csl = safe_move_tensor(csl, device, non_blocking = True)
        self.cu_seqlens = csl

    return csl
|
||||
|
||||
def get_cu_seqlens_max(self) -> int:
    """
    Return the largest gap between consecutive entries of `cu_seqlens`.

    The value is computed once and cached in `self.cu_seqlens_max`; subsequent calls return the
    cached number without looking at `cu_seqlens` again, so the cache is not refreshed if
    `cu_seqlens` is reassigned afterwards.

    :return:
        Python number obtained via `.item()` from the maximum consecutive difference.

    :raises AssertionError:
        If `cu_seqlens` has not been set.
    """

    assert self.cu_seqlens is not None

    if self.cu_seqlens_max is None:
        # Differences of neighbouring boundaries are the individual segment lengths
        segment_lengths = self.cu_seqlens[1:] - self.cu_seqlens[:-1]
        self.cu_seqlens_max = segment_lengths.max().item()

    return self.cu_seqlens_max
|
||||
|
||||
def get_block_diag_mask(self, device: int) -> torch.Tensor | None:
|
||||
if self.block_diag_mask is None:
|
||||
csl = self.get_cu_seqlens(device)
|
||||
if csl is None:
|
||||
return None
|
||||
positions = torch.arange(csl[-1], device = csl.device)
|
||||
labels = torch.searchsorted(csl[1:], positions, right = True)
|
||||
self.block_diag_mask = labels.unsqueeze(0) == labels.unsqueeze(1).repeat(self.batch_size)
|
||||
if self.block_diag_mask.device.index != device:
|
||||
self.block_diag_mask = safe_move_tensor(self.block_diag_mask, device, non_blocking = True)
|
||||
return self.block_diag_mask
|
||||
|
||||
|
||||
class PagedParams(Params):
|
||||
|
||||
@@ -135,6 +135,7 @@ class ExLlamaV2Config:
|
||||
vision_num_key_value_groups: int | None
|
||||
vision_hidden_size: int | None
|
||||
vision_intermediate_size: int | None
|
||||
vision_merger_intermediate_size: int | None
|
||||
vision_hidden_act: str | None
|
||||
vision_rope_theta: float | None
|
||||
vision_feature_layer: int | None
|
||||
@@ -152,6 +153,8 @@ class ExLlamaV2Config:
|
||||
vision_max_pixels: int | None
|
||||
vision_temporal_patch_size: int | None
|
||||
vision_max_size: int | None
|
||||
vision_fullatt_block_indexes: list | None
|
||||
vision_window_size: int | None
|
||||
|
||||
# Deprecated fields, kept for compatibility
|
||||
|
||||
@@ -478,6 +481,8 @@ class ExLlamaV2Config:
|
||||
|
||||
# TODO: Cleanup & refactor
|
||||
|
||||
self.vision_fullatt_block_indexes = None
|
||||
|
||||
if self.vision_model_type is None:
|
||||
pass
|
||||
|
||||
@@ -495,6 +500,7 @@ class ExLlamaV2Config:
|
||||
self.vision_feature_layer = read(read_config, int, ["vision_feature_layer"], no_default)
|
||||
self.vision_num_layers = read(read_config, int, ["vision_config->num_hidden_layers"], 24)
|
||||
self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], self.hidden_size)
|
||||
self.vision_merger_intermediate_size = self.vision_intermediate_size
|
||||
|
||||
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
|
||||
assert image_processor_type == "PixtralImageProcessor", \
|
||||
@@ -511,10 +517,27 @@ class ExLlamaV2Config:
|
||||
self.vision_spatial_merge_size = 1
|
||||
self.vision_max_size = 16384
|
||||
|
||||
elif self.vision_model_type == "qwen2":
|
||||
elif self.vision_model_type in ["qwen2", "qwen2.5"]:
|
||||
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
|
||||
if self.vision_model_type == "qwen2":
|
||||
self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
|
||||
mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], None)
|
||||
self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
|
||||
self.vision_merger_intermediate_size = self.vision_intermediate_size
|
||||
assert image_processor_type == "Qwen2VLImageProcessor", \
|
||||
f"Wrong image processor type: {image_processor_type}"
|
||||
self.vision_window_size = None
|
||||
elif self.vision_model_type == "qwen2.5":
|
||||
self.vision_hidden_size = read(read_config, int, ["vision_config->hidden_size"], no_default)
|
||||
self.vision_intermediate_size = read(read_config, int, ["vision_config->intermediate_size"], no_default)
|
||||
self.vision_fullatt_block_indexes = read(read_config, list, ["vision_config->fullatt_block_indexes", None])
|
||||
self.vision_window_size = read(read_config, int, ["vision_config->window_size", None])
|
||||
assert image_processor_type == "Qwen2_5_VLImageProcessor", \
|
||||
f"Wrong image processor type: {image_processor_type}"
|
||||
self.vision_merger_intermediate_size = 5120 # TODO: This doesn't seem to appear in the config anywhere?
|
||||
|
||||
self.vision_num_attention_heads = read(read_config, int, ["vision_config->num_heads"], no_default)
|
||||
self.vision_num_key_value_heads = self.vision_num_attention_heads
|
||||
self.vision_hidden_size = read(read_config, int, ["vision_config->embed_dim"], no_default)
|
||||
self.vision_head_dim = self.vision_hidden_size // self.vision_num_attention_heads
|
||||
self.vision_num_key_value_groups = 1
|
||||
self.vision_hidden_act = "quickgelu"
|
||||
@@ -523,12 +546,7 @@ class ExLlamaV2Config:
|
||||
patch_size = read(read_config, int, ["vision_config->patch_size"], no_default)
|
||||
self.vision_rope_theta = read(read_config, int, ["vision_config->rope_theta"], 10000.0)
|
||||
self.vision_num_layers = read(read_config, int, ["vision_config->depth"], no_default)
|
||||
mlp_ratio = read(read_config, int, ["vision_config->mlp_ratio"], no_default)
|
||||
self.vision_intermediate_size = self.vision_hidden_size * mlp_ratio
|
||||
|
||||
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
|
||||
assert image_processor_type == "Qwen2VLImageProcessor", \
|
||||
f"Wrong image processor type: {image_processor_type}"
|
||||
self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
|
||||
self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
|
||||
assert read(read_prep_config, int, ["patch_size"], no_default) == patch_size, \
|
||||
|
||||
@@ -51,6 +51,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
out_features: int | None = None,
|
||||
interm_features: int | None = None,
|
||||
merge: int | None = None,
|
||||
pad32: bool = True,
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
cfg = self.model.config
|
||||
@@ -98,8 +99,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.pre_layernorm = None
|
||||
self.post_layernorm = None
|
||||
|
||||
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c)
|
||||
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth)
|
||||
self.up_proj = ExLlamaV2Linear(model, key + km["mlp_up"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_b, f_end = f_c, pad32 = pad32)
|
||||
self.down_proj = ExLlamaV2Linear(model, key + km["mlp_down"], interm_features, out_features, ap.mlp_bias, prescale = cfg.scale_depth, pad32 = pad32)
|
||||
|
||||
self.submodules = [self.up_proj,
|
||||
self.down_proj]
|
||||
@@ -109,7 +110,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.submodules += [self.post_layernorm]
|
||||
|
||||
if ap.mlp_gate:
|
||||
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b)
|
||||
self.gate_proj = ExLlamaV2Linear(model, key + km["mlp_gate"], in_features, interm_features, ap.mlp_bias, f_key = f_key, f_beg = f_a, f_end = f_b, pad32 = pad32)
|
||||
self.submodules += [self.gate_proj]
|
||||
else:
|
||||
self.gate_proj = None
|
||||
|
||||
@@ -44,7 +44,7 @@ def preprocess(
|
||||
|
||||
image = image.transpose(2, 0, 1)
|
||||
image = torch.from_numpy(image).half()
|
||||
return image, new_size
|
||||
return image, new_size, None, None
|
||||
|
||||
def postprocess(
|
||||
model: ExLlamaV2,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
@@ -86,7 +87,7 @@ def preprocess(
|
||||
|
||||
if mode == "image":
|
||||
image = torch.from_numpy(flatten_patches).half()
|
||||
return image, new_size
|
||||
return image, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
|
||||
else:
|
||||
video = torch.from_numpy(flatten_patches).half()
|
||||
return video, new_size, (grid_t, grid_h, grid_w), config.vision_spatial_patch_size ** 2
|
||||
@@ -149,4 +150,51 @@ def position_embeddings(
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).contiguous()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).contiguous()
|
||||
|
||||
return sin, cos
|
||||
return sin, cos
|
||||
|
||||
|
||||
def get_window_index(grid_thw, config: "ExLlamaV2Config"):
    """
    Compute a window-major ordering of merged patch indices plus cumulative window lengths.

    For every (t, h, w) row of `grid_thw`, the h and w extents are divided by
    `config.vision_spatial_merge_size`, the resulting grid is tiled into square windows, and the
    indices of the merged positions are listed window by window.

    :param grid_thw:
        Iterable of (grid_t, grid_h, grid_w) triples. Entries may be plain ints or 0-d integer
        tensors (e.g. rows of a 2-D tensor); they are converted with `int()`.

    :param config:
        Object providing `vision_window_size`, `vision_spatial_merge_size` and
        `vision_patch_size["height"]`.

    :return:
        Tuple `(window_index, cu_window_seqlens)`. `window_index` is a 1-D tensor of merged-position
        indices in window order, offset per input row so that rows index into one concatenated
        sequence. `cu_window_seqlens` is a list of ints starting at 0, holding cumulative window
        lengths scaled by `vision_spatial_merge_size ** 2`. Windows that consist only of padding
        contribute a repeated value, so consecutive duplicates can occur (always when a grid side
        is an exact multiple of the window size, because a full extra window of padding is added).
    """

    merge = config.vision_spatial_merge_size
    vit_merger_window_size = (
        config.vision_window_size //
        merge //
        config.vision_patch_size["height"]
    )

    window_index: list = []
    cu_window_seqlens: list = [0]
    window_index_id = 0

    for thw in grid_thw:

        # Plain ints keep the arithmetic below independent of whether the caller passed tensors
        grid_t, grid_h, grid_w = (int(v) for v in thw)
        llm_grid_h = grid_h // merge
        llm_grid_w = grid_w // merge
        num_merged = grid_t * llm_grid_h * llm_grid_w

        index = torch.arange(num_merged).reshape(grid_t, llm_grid_h, llm_grid_w)

        # Pad bottom/right up to the next window boundary; -100 marks padding positions
        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)

        # Split both spatial axes into (window, offset-in-window), then group the window axes
        index_padded = index_padded.reshape(
            grid_t,
            num_windows_h,
            vit_merger_window_size,
            num_windows_w,
            vit_merger_window_size,
        )
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
            grid_t,
            num_windows_h * num_windows_w,
            vit_merger_window_size,
            vit_merger_window_size,
        )

        # Number of real (non-padding) positions in each window
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)

        index_padded = index_padded.reshape(-1)
        index_new = index_padded[index_padded != -100]
        window_index.append(index_new + window_index_id)

        cu_seqlens_tmp = seqlens.cumsum(0) * merge ** 2 + cu_window_seqlens[-1]
        cu_window_seqlens.extend(cu_seqlens_tmp.tolist())

        window_index_id += num_merged

    window_index = torch.cat(window_index, dim = 0)
    return window_index, cu_window_seqlens
|
||||
@@ -4,6 +4,7 @@ import os, sys
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.conv import ExLlamaV2Conv
|
||||
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
|
||||
@@ -44,7 +45,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
self.postprocess_func = pixtral.postprocess
|
||||
self.video_preprocess_func = None
|
||||
self.video_postprocess_func = None
|
||||
elif cfg.vision_model_type == "qwen2":
|
||||
elif cfg.vision_model_type in ["qwen2", "qwen2.5"]:
|
||||
self.preprocess_func = qwen2.preprocess
|
||||
self.postprocess_func = qwen2.postprocess
|
||||
self.video_preprocess_func = qwen2.preprocess
|
||||
@@ -76,7 +77,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
|
||||
self.position_emb_func = pixtral.position_embeddings
|
||||
|
||||
elif cfg.vision_model_type == "qwen2":
|
||||
elif cfg.vision_model_type in ["qwen2", "qwen2.5"]:
|
||||
self.p_maxedge = cfg.vision_max_size
|
||||
dim = cfg.vision_head_dim // 2
|
||||
max_seqlen = int(math.ceil(cfg.vision_max_size / cfg.vision_spatial_patch_size))
|
||||
@@ -130,7 +131,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
for layer_idx in range(self.config.vision_num_layers):
|
||||
layer_key = cfg.arch.vt_prefix + km["layers"] + f".{layer_idx}"
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx, archparams = self.archparams)
|
||||
mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams)
|
||||
mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams, pad32 = False)
|
||||
self.modules += [attn, mlp]
|
||||
|
||||
# Multimodal projection
|
||||
@@ -143,10 +144,11 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
archparams = cfg.arch.mmp,
|
||||
in_features = cfg.vision_hidden_size * merge,
|
||||
out_features = cfg.hidden_size,
|
||||
interm_features = cfg.vision_intermediate_size,
|
||||
interm_features = cfg.vision_merger_intermediate_size,
|
||||
has_norm = True,
|
||||
has_residual = False,
|
||||
merge = merge,
|
||||
pad32 = False,
|
||||
)
|
||||
self.modules += [mmp]
|
||||
|
||||
@@ -199,6 +201,19 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
)
|
||||
|
||||
attn_params = ExLlamaV2Attention.Params(non_causal_attn = True)
|
||||
if self.config.vision_window_size:
|
||||
attn_params.block_diag_layers = set([
|
||||
l for l in range(self.config.vision_num_layers)
|
||||
if l not in self.config.vision_fullatt_block_indexes
|
||||
])
|
||||
|
||||
thw_grid_t = torch.tensor(list(thw_grid), dtype = torch.int).unsqueeze(0)
|
||||
window_index, csl = qwen2.get_window_index(thw_grid_t, self.config)
|
||||
csl = torch.tensor(csl, device = hidden_states.device, dtype = torch.int)
|
||||
csl = torch.unique_consecutive(csl)
|
||||
attn_params.cu_seqlens = csl
|
||||
else:
|
||||
window_index = None
|
||||
|
||||
device = self.modules[0].device_idx
|
||||
for idx, module in enumerate(self.modules):
|
||||
@@ -230,13 +245,33 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
hidden_states,
|
||||
attn_params = attn_params,
|
||||
**kwargs | {
|
||||
"alt_rope_embedding": (cos, sin)
|
||||
"alt_rope_embedding": (cos, sin),
|
||||
}
|
||||
)
|
||||
|
||||
if thw_grid is not None and isinstance(module, ExLlamaV2Attention):
|
||||
hidden_states = hidden_states.view(pa_shape)
|
||||
|
||||
if window_index is not None and idx == 0:
|
||||
sh = hidden_states.shape
|
||||
unit = self.config.vision_spatial_merge_size ** 2
|
||||
seq_len = hidden_states.shape[0] * hidden_states.shape[1]
|
||||
hidden_states = hidden_states.reshape(seq_len // unit, unit, -1)
|
||||
hidden_states = hidden_states[window_index, :, :]
|
||||
hidden_states = hidden_states.reshape(sh)
|
||||
sh = sin.shape
|
||||
seq_len = sin.shape[0]
|
||||
sin = sin.reshape(seq_len // unit, unit, -1)
|
||||
sin = sin[window_index[:seq_len // unit], :, :]
|
||||
sin = sin.reshape(sh)
|
||||
cos = cos.reshape(seq_len // unit, unit, -1)
|
||||
cos = cos[window_index[:seq_len // unit], :, :]
|
||||
cos = cos.reshape(sh)
|
||||
|
||||
if window_index is not None:
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[:, reverse_indices]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -278,13 +313,14 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
assert all(s <= maxsize for s in original_size), \
|
||||
f"Input image exceeds maximum size of {maxsize} x {maxsize}"
|
||||
|
||||
image_tensor, prep_image_size = self.preprocess_func(self.config, image)
|
||||
image_tensor, prep_image_size, grid_thw, _ = self.preprocess_func(self.config, image)
|
||||
features_x = prep_image_size[0] // self.config.vision_patch_size["width"]
|
||||
features_y = prep_image_size[1] // self.config.vision_patch_size["height"]
|
||||
|
||||
embedding_tensor = self.process(
|
||||
image_tensor,
|
||||
(features_y, features_x)
|
||||
(features_y, features_x),
|
||||
thw_grid = grid_thw
|
||||
)
|
||||
|
||||
if embeddings_cpu:
|
||||
|
||||
Reference in New Issue
Block a user