Add support for Mistral 3.1 VLM

This commit is contained in:
turboderp
2025-04-18 22:47:47 +02:00
parent 68f7461985
commit 9244003a40
6 changed files with 91 additions and 13 deletions

View File

@@ -26,6 +26,8 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
# Pixtral:
# https://huggingface.co/mistral-community/pixtral-12b/
# https://huggingface.co/turboderp/pixtral-12b-exl2
# Mistral-Small 3.1:
# https://huggingface.co/prince-canuma/Mistral-Small-3.1-24B-Instruct-2503
# Qwen2-VL:
# https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
# https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
@@ -34,8 +36,9 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200)
# https://huggingface.co/turboderp/gemma-3-27b-it-exl2
# mode = "pixtral"
mode = "mistral3"
# mode = "qwen2"
mode = "gemma3"
# mode = "gemma3"
streaming = True
greedy = True
@@ -43,9 +46,11 @@ greedy = True
if mode == "pixtral":
model_directory = "/mnt/str/models/pixtral-12b-exl2/6.0bpw"
elif mode == "qwen2":
model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/5.0bpw"
elif mode == "gemma3":
model_directory = "/mnt/str/models/gemma3-27b-it-exl2/5.0bpw"
model_directory = "/mnt/str/models/gemma3-12b-it-exl2/6.0bpw"
elif mode == "mistral3":
model_directory = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl2/4.5bpw"
images = [
# {"file": "media/test_image_1.jpg"},
@@ -62,7 +67,7 @@ instruction = "Describe the image."
# Initialize model
config = ExLlamaV2Config(model_directory)
config.max_seq_len = 16384 # Pixtral default is 1M
config.max_seq_len = 8192 # Pixtral default is 1M
# Load vision model and multimodal projector and initialize preprocessor
@@ -72,8 +77,8 @@ vision_model.load(progress = True)
# Load EXL2 model
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384)
model.load_autosplit(cache, progress = True)
cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
model.load_autosplit(progress = True, cache = cache)
tokenizer = ExLlamaV2Tokenizer(config)
# Create generator
@@ -121,7 +126,7 @@ placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n"
# Image token IDs are assigned sequentially, however, so two ExLlamaV2Embedding objects created from the same
# source image will not be recognized as the same image for purposes of prompt caching etc.
if mode == "pixtral":
if mode in ["pixtral", "mistral3"]:
prompt = (
"[INST]" +
placeholders +

View File

@@ -241,6 +241,8 @@ class ExLlamaV2ArchParams:
rope_freq_half: bool = False
learned_emb: bool = False
output_norm: bool = False
mlp_merger: bool = False
mlp_patch_merger: bool = False
# Component models
self.lm_prefix = ""
@@ -340,6 +342,52 @@ class ExLlamaV2ArchParams:
self.mmp.mlp_act_func = "gelu"
self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
# Mistral 3 multimodal
if (
arch_string == "Mistral3ForConditionalGeneration" and
"vision_config" in read_config and
read_config["vision_config"].get("model_type") == "pixtral"
):
arch_recognized = True
self.lm_prefix = "language_model."
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.lm.expect_keys += \
expect_keys_llama
self.vt_prefix = "vision_tower."
self.vt.keys.update({
"attn_q": ".attention.q_proj",
"attn_k": ".attention.k_proj",
"attn_v": ".attention.v_proj",
"attn_o": ".attention.o_proj",
"mlp_gate": ".feed_forward.gate_proj",
"mlp_up": ".feed_forward.up_proj",
"mlp_down": ".feed_forward.down_proj",
"norm_1": ".attention_norm",
"norm_2": ".ffn_norm",
"layers": "transformer.layers",
"ln_pre": "ln_pre",
})
self.vt.mlp_merger = True
self.vt.mlp_patch_merger = True
self.mmp_prefix = "multi_modal_projector."
self.mmp.keys.update({
"norm_2": "norm",
"mlp_gate": None,
"mlp_up": "linear_1",
"mlp_down": "linear_2",
"patch_merger": "patch_merger.merging_layer",
})
self.mmp.mlp_patch_merger = True
self.mmp.mlp_gate = False
self.mmp.mlp_act_func = "gelu"
self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
# Yi
if arch_string == "YiForCausalLM":

View File

@@ -552,18 +552,20 @@ class ExLlamaV2Config:
self.vision_merger_intermediate_size = self.vision_intermediate_size
image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
assert image_processor_type == "PixtralImageProcessor", \
assert image_processor_type == "PixtralImageProcessor" or image_processor_type == "PixtralImageProcessorFast", \
f"Wrong image processor type: {image_processor_type}"
self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
self.vision_patch_size = read(read_prep_config, dict, ["patch_size"], no_default)
self.vision_patch_size = read(read_prep_config, object, ["patch_size"], no_default)
if isinstance(self.vision_patch_size, int):
self.vision_patch_size = {"width": self.vision_patch_size, "height": self.vision_patch_size}
assert all(self.vision_patch_size.get(x) == patch_size for x in ["width", "height"]), \
"Patch size inconsistency between config.json and preprocessor_config.json"
self.vision_resample = read(read_prep_config, int, ["resample"], no_default)
self.vision_rescale_factor = read(read_prep_config, float, ["rescale_factor"], no_default)
self.vision_size = read(read_prep_config, dict, ["size"], no_default)
self.vision_num_channels = 3
self.vision_spatial_merge_size = 1
self.vision_spatial_merge_size = read(read_config, int, ["spatial_merge_size"], 1)
self.vision_max_size = 16384
self.vision_window_size = None

View File

@@ -115,6 +115,11 @@ class ExLlamaV2MLP(ExLlamaV2Module):
else:
self.gate_proj = None
if merge and ap.mlp_patch_merger:
self.patch_merger_proj = ExLlamaV2Linear(model, key + km["patch_merger"], in_features * merge**2, in_features, ap.mlp_bias)
self.submodules += [self.patch_merger_proj]
else:
self.patch_merger_proj = None
def numel(self) -> int:
@@ -158,6 +163,9 @@ class ExLlamaV2MLP(ExLlamaV2Module):
if self.gate_proj is not None: self.gate_proj.load(device_context = device_context, output_map = down_map)
self.up_proj.load(device_context = device_context, output_map = down_map)
if self.patch_merger_proj is not None:
self.patch_merger_proj.load()
if self.up_proj.is_quant():
assert self.gate_proj is None or self.gate_proj.is_quant()
assert self.up_proj.is_quant(), "Partially quantized MLP layer"
@@ -302,6 +310,8 @@ class ExLlamaV2MLP(ExLlamaV2Module):
if self.gate_proj is not None: self.gate_proj.set_device_idx(idx)
self.up_proj.set_device_idx(idx)
self.down_proj.set_device_idx(idx)
if self.patch_merger_proj is not None:
self.patch_merger_proj.set_device_idx(idx)
# @profile
@@ -458,9 +468,18 @@ class ExLlamaV2MLP(ExLlamaV2Module):
if self.pre_layernorm else hidden_states
if self.merge:
bd = post_norm.shape[:-2]
l, d = post_norm.shape[-2:]
post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
if self.archparams.mlp_patch_merger:
bsz = hidden_states.shape[0]
assert bsz == 1
(h, w), d = kwargs["patch_size"], hidden_states.shape[-1]
image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
grid = F.unfold(image_grid, kernel_size = int(self.merge ** 0.5), stride = int(self.merge ** 0.5))
grid = grid.view(bsz, d * self.merge, -1).transpose(1, 2)
post_norm = self.patch_merger_proj.forward(grid)
else:
bd = post_norm.shape[:-2]
l, d = post_norm.shape[-2:]
post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
if self.gate_proj is not None:
gate = self.gate_proj.forward(post_norm, loras = loras)

View File

@@ -57,6 +57,9 @@ def postprocess(
Insert [IMG_BREAK] and [IMG_END] tokens in image feature embeddings
"""
features_x //= model.config.vision_spatial_merge_size
features_y //= model.config.vision_spatial_merge_size
assert embeddings.shape[0] == features_y * features_x, \
"Invalid shape for embeddings"

View File

@@ -309,6 +309,7 @@ class ExLlamaV2VisionTower(ExLlamaV2):
attn_params = attn_params,
**kwargs | ({
"alt_rope_embedding": (cos, sin),
"patch_size": (p_height, p_width),
} if cos is not None else {})
)