mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
Add support for Mistral 3.1 VLM
This commit is contained in:
@@ -26,6 +26,8 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
|
||||
# Pixtral:
|
||||
# https://huggingface.co/mistral-community/pixtral-12b/
|
||||
# https://huggingface.co/turboderp/pixtral-12b-exl2
|
||||
# Mistral-Small 3.1:
|
||||
# https://huggingface.co/prince-canuma/Mistral-Small-3.1-24B-Instruct-2503
|
||||
# Qwen2-VL:
|
||||
# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
|
||||
# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
|
||||
@@ -34,8 +36,9 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
|
||||
# https://huggingface.co/turboderp/gemma-3-27b-it-exl2
|
||||
|
||||
# mode = "pixtral"
|
||||
mode = "mistral3"
|
||||
# mode = "qwen2"
|
||||
mode = "gemma3"
|
||||
# mode = "gemma3"
|
||||
|
||||
streaming = True
|
||||
greedy = True
|
||||
@@ -43,9 +46,11 @@ greedy = True
|
||||
if mode == "pixtral":
|
||||
model_directory = "/mnt/str/models/pixtral-12b-exl2/6.0bpw"
|
||||
elif mode == "qwen2":
|
||||
model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
|
||||
model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/5.0bpw"
|
||||
elif mode == "gemma3":
|
||||
model_directory = "/mnt/str/models/gemma3-27b-it-exl2/5.0bpw"
|
||||
model_directory = "/mnt/str/models/gemma3-12b-it-exl2/6.0bpw"
|
||||
elif mode == "mistral3":
|
||||
model_directory = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl2/4.5bpw"
|
||||
|
||||
images = [
|
||||
# {"file": "media/test_image_1.jpg"},
|
||||
@@ -62,7 +67,7 @@ instruction = "Describe the image."
|
||||
# Initialize model
|
||||
|
||||
config = ExLlamaV2Config(model_directory)
|
||||
config.max_seq_len = 16384 # Pixtral default is 1M
|
||||
config.max_seq_len = 8192 # Pixtral default is 1M
|
||||
|
||||
# Load vision model and multimodal projector and initialize preprocessor
|
||||
|
||||
@@ -72,8 +77,8 @@ vision_model.load(progress = True)
|
||||
# Load EXL2 model
|
||||
|
||||
model = ExLlamaV2(config)
|
||||
cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384)
|
||||
model.load_autosplit(cache, progress = True)
|
||||
cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
|
||||
model.load_autosplit(progress = True, cache = cache)
|
||||
tokenizer = ExLlamaV2Tokenizer(config)
|
||||
|
||||
# Create generator
|
||||
@@ -121,7 +126,7 @@ placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n"
|
||||
# Image token IDs are assigned sequentially, however, so two ExLlamaV2Embedding objects created from the same
|
||||
# source image will not be recognized as the same image for purposes of prompt caching etc.
|
||||
|
||||
if mode == "pixtral":
|
||||
if mode in ["pixtral", "mistral3"]:
|
||||
prompt = (
|
||||
"[INST]" +
|
||||
placeholders +
|
||||
|
||||
@@ -241,6 +241,8 @@ class ExLlamaV2ArchParams:
|
||||
rope_freq_half: bool = False
|
||||
learned_emb: bool = False
|
||||
output_norm: bool = False
|
||||
mlp_merger: bool = False
|
||||
mlp_patch_merger: bool = False
|
||||
|
||||
# Component models
|
||||
self.lm_prefix = ""
|
||||
@@ -340,6 +342,52 @@ class ExLlamaV2ArchParams:
|
||||
self.mmp.mlp_act_func = "gelu"
|
||||
self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
|
||||
|
||||
# Mistral 3 multimodal
|
||||
|
||||
if (
|
||||
arch_string == "Mistral3ForConditionalGeneration" and
|
||||
"vision_config" in read_config and
|
||||
read_config["vision_config"].get("model_type") == "pixtral"
|
||||
):
|
||||
arch_recognized = True
|
||||
self.lm_prefix = "language_model."
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_llama_mlp
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
|
||||
self.vt_prefix = "vision_tower."
|
||||
self.vt.keys.update({
|
||||
"attn_q": ".attention.q_proj",
|
||||
"attn_k": ".attention.k_proj",
|
||||
"attn_v": ".attention.v_proj",
|
||||
"attn_o": ".attention.o_proj",
|
||||
"mlp_gate": ".feed_forward.gate_proj",
|
||||
"mlp_up": ".feed_forward.up_proj",
|
||||
"mlp_down": ".feed_forward.down_proj",
|
||||
"norm_1": ".attention_norm",
|
||||
"norm_2": ".ffn_norm",
|
||||
"layers": "transformer.layers",
|
||||
"ln_pre": "ln_pre",
|
||||
})
|
||||
self.vt.mlp_merger = True
|
||||
self.vt.mlp_patch_merger = True
|
||||
|
||||
self.mmp_prefix = "multi_modal_projector."
|
||||
self.mmp.keys.update({
|
||||
"norm_2": "norm",
|
||||
"mlp_gate": None,
|
||||
"mlp_up": "linear_1",
|
||||
"mlp_down": "linear_2",
|
||||
"patch_merger": "patch_merger.merging_layer",
|
||||
})
|
||||
self.mmp.mlp_patch_merger = True
|
||||
self.mmp.mlp_gate = False
|
||||
self.mmp.mlp_act_func = "gelu"
|
||||
self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
|
||||
|
||||
# Yi
|
||||
|
||||
if arch_string == "YiForCausalLM":
|
||||
|
||||
@@ -552,18 +552,20 @@ class ExLlamaV2Config:
|
||||
self.vision_merger_intermediate_size = self.vision_intermediate_size
|
||||
|
||||
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
|
||||
assert image_processor_type == "PixtralImageProcessor", \
|
||||
assert image_processor_type == "PixtralImageProcessor" or image_processor_type == "PixtralImageProcessorFast", \
|
||||
f"Wrong image processor type: {image_processor_type}"
|
||||
self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
|
||||
self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
|
||||
self.vision_patch_size = read(read_prep_config, dict, ["patch_size"], no_default)
|
||||
self.vision_patch_size = read(read_prep_config, object, ["patch_size"], no_default)
|
||||
if isinstance(self.vision_patch_size, int):
|
||||
self.vision_patch_size = {"width": self.vision_patch_size, "height": self.vision_patch_size}
|
||||
assert all(self.vision_patch_size.get(x) == patch_size for x in ["width", "height"]), \
|
||||
"Patch size inconsistency between config.json and preprocessor_config.json"
|
||||
self.vision_resample = read(read_prep_config, int, ["resample"], no_default)
|
||||
self.vision_rescale_factor = read(read_prep_config, float, ["rescale_factor"], no_default)
|
||||
self.vision_size = read(read_prep_config, dict, ["size"], no_default)
|
||||
self.vision_num_channels = 3
|
||||
self.vision_spatial_merge_size = 1
|
||||
self.vision_spatial_merge_size = read(read_config, int, ["spatial_merge_size"], 1)
|
||||
self.vision_max_size = 16384
|
||||
self.vision_window_size = None
|
||||
|
||||
|
||||
@@ -115,6 +115,11 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
else:
|
||||
self.gate_proj = None
|
||||
|
||||
if merge and ap.mlp_patch_merger:
|
||||
self.patch_merger_proj = ExLlamaV2Linear(model, key + km["patch_merger"], in_features * merge**2, in_features, ap.mlp_bias)
|
||||
self.submodules += [self.patch_merger_proj]
|
||||
else:
|
||||
self.patch_merger_proj = None
|
||||
|
||||
def numel(self) -> int:
|
||||
|
||||
@@ -158,6 +163,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
if self.gate_proj is not None: self.gate_proj.load(device_context = device_context, output_map = down_map)
|
||||
self.up_proj.load(device_context = device_context, output_map = down_map)
|
||||
|
||||
if self.patch_merger_proj is not None:
|
||||
self.patch_merger_proj.load()
|
||||
|
||||
if self.up_proj.is_quant():
|
||||
assert self.gate_proj is None or self.gate_proj.is_quant()
|
||||
assert self.up_proj.is_quant(), "Partially quantized MLP layer"
|
||||
@@ -302,6 +310,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
if self.gate_proj is not None: self.gate_proj.set_device_idx(idx)
|
||||
self.up_proj.set_device_idx(idx)
|
||||
self.down_proj.set_device_idx(idx)
|
||||
if self.patch_merger_proj is not None:
|
||||
self.patch_merger_proj.set_device_idx(idx)
|
||||
|
||||
|
||||
# @profile
|
||||
@@ -458,9 +468,18 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
if self.pre_layernorm else hidden_states
|
||||
|
||||
if self.merge:
|
||||
bd = post_norm.shape[:-2]
|
||||
l, d = post_norm.shape[-2:]
|
||||
post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
|
||||
if self.archparams.mlp_patch_merger:
|
||||
bsz = hidden_states.shape[0]
|
||||
assert bsz == 1
|
||||
(h, w), d = kwargs["patch_size"], hidden_states.shape[-1]
|
||||
image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
|
||||
grid = F.unfold(image_grid, kernel_size = int(self.merge ** 0.5), stride = int(self.merge ** 0.5))
|
||||
grid = grid.view(bsz, d * self.merge, -1).transpose(1, 2)
|
||||
post_norm = self.patch_merger_proj.forward(grid)
|
||||
else:
|
||||
bd = post_norm.shape[:-2]
|
||||
l, d = post_norm.shape[-2:]
|
||||
post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
|
||||
|
||||
if self.gate_proj is not None:
|
||||
gate = self.gate_proj.forward(post_norm, loras = loras)
|
||||
|
||||
@@ -57,6 +57,9 @@ def postprocess(
|
||||
Insert [IMG_BREAK] and [IMG_END] tokens in image feature embeddings
|
||||
"""
|
||||
|
||||
features_x //= model.config.vision_spatial_merge_size
|
||||
features_y //= model.config.vision_spatial_merge_size
|
||||
|
||||
assert embeddings.shape[0] == features_y * features_x, \
|
||||
"Invalid shape for embeddings"
|
||||
|
||||
|
||||
@@ -309,6 +309,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
|
||||
attn_params = attn_params,
|
||||
**kwargs | ({
|
||||
"alt_rope_embedding": (cos, sin),
|
||||
"patch_size": (p_height, p_width),
|
||||
} if cos is not None else {})
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user