diff --git a/examples/multimodal.py b/examples/multimodal.py index 2ab36db..f0ff5af 100644 --- a/examples/multimodal.py +++ b/examples/multimodal.py @@ -26,6 +26,8 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200) # Pixtral: # https://huggingface.co/mistral-community/pixtral-12b/ # https://huggingface.co/turboderp/pixtral-12b-exl2 +# Mistral-Small 3.1: +# https://huggingface.co/prince-canuma/Mistral-Small-3.1-24B-Instruct-2503 # Qwen2-VL: # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct # https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2 @@ -34,8 +36,9 @@ torch.set_printoptions(precision = 5, sci_mode = False, linewidth=200) # https://huggingface.co/turboderp/gemma-3-27b-it-exl2 # mode = "pixtral" +mode = "mistral3" # mode = "qwen2" -mode = "gemma3" +# mode = "gemma3" streaming = True greedy = True @@ -43,9 +46,11 @@ greedy = True if mode == "pixtral": model_directory = "/mnt/str/models/pixtral-12b-exl2/6.0bpw" elif mode == "qwen2": - model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw" + model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/5.0bpw" elif mode == "gemma3": - model_directory = "/mnt/str/models/gemma3-27b-it-exl2/5.0bpw" + model_directory = "/mnt/str/models/gemma3-12b-it-exl2/6.0bpw" +elif mode == "mistral3": + model_directory = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl2/4.5bpw" images = [ # {"file": "media/test_image_1.jpg"}, @@ -62,7 +67,7 @@ instruction = "Describe the image." # Initialize model config = ExLlamaV2Config(model_directory) -config.max_seq_len = 16384 # Pixtral default is 1M +config.max_seq_len = 8192 # Pixtral default is 1M # Load vision model and multimodal projector and initialize preprocessor @@ -72,8 +77,8 @@ vision_model.load(progress = True) # Load EXL2 model model = ExLlamaV2(config) -cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384) -model.load_autosplit(cache, progress = True) +cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True) +model.load_autosplit(progress = True, cache = cache) tokenizer = ExLlamaV2Tokenizer(config) # Create generator @@ -121,7 +126,7 @@ placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n" # Image token IDs are assigned sequentially, however, so two ExLlamaV2Embedding objects created from the same # source image will not be recognized as the same image for purposes of prompt caching etc. -if mode == "pixtral": +if mode in ["pixtral", "mistral3"]: prompt = ( "[INST]" + placeholders + diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 7780f80..300fe0c 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -241,6 +241,8 @@ class ExLlamaV2ArchParams: rope_freq_half: bool = False learned_emb: bool = False output_norm: bool = False + mlp_merger: bool = False + mlp_patch_merger: bool = False # Component models self.lm_prefix = "" @@ -340,6 +342,52 @@ class ExLlamaV2ArchParams: self.mmp.mlp_act_func = "gelu" self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True)) + # Mistral 3 multimodal + + if ( + arch_string == "Mistral3ForConditionalGeneration" and + "vision_config" in read_config and + read_config["vision_config"].get("model_type") == "pixtral" + ): + arch_recognized = True + self.lm_prefix = "language_model." + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.lm.expect_keys += \ + expect_keys_llama + + self.vt_prefix = "vision_tower." + self.vt.keys.update({ + "attn_q": ".attention.q_proj", + "attn_k": ".attention.k_proj", + "attn_v": ".attention.v_proj", + "attn_o": ".attention.o_proj", + "mlp_gate": ".feed_forward.gate_proj", + "mlp_up": ".feed_forward.up_proj", + "mlp_down": ".feed_forward.down_proj", + "norm_1": ".attention_norm", + "norm_2": ".ffn_norm", + "layers": "transformer.layers", + "ln_pre": "ln_pre", + }) + self.vt.mlp_merger = True + self.vt.mlp_patch_merger = True + + self.mmp_prefix = "multi_modal_projector." + self.mmp.keys.update({ + "norm_2": "norm", + "mlp_gate": None, + "mlp_up": "linear_1", + "mlp_down": "linear_2", + "patch_merger": "patch_merger.merging_layer", + }) + self.mmp.mlp_patch_merger = True + self.mmp.mlp_gate = False + self.mmp.mlp_act_func = "gelu" + self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True)) + # Yi if arch_string == "YiForCausalLM": diff --git a/exllamav2/config.py b/exllamav2/config.py index 3ae2c80..2f83e27 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -552,18 +552,20 @@ class ExLlamaV2Config: self.vision_merger_intermediate_size = self.vision_intermediate_size image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default) - assert image_processor_type == "PixtralImageProcessor", \ + assert image_processor_type == "PixtralImageProcessor" or image_processor_type == "PixtralImageProcessorFast", \ f"Wrong image processor type: {image_processor_type}" self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default) self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default) - self.vision_patch_size = read(read_prep_config, dict, ["patch_size"], no_default) + self.vision_patch_size = read(read_prep_config, object, ["patch_size"], no_default) + if isinstance(self.vision_patch_size, int): + self.vision_patch_size = {"width": self.vision_patch_size, "height": self.vision_patch_size} assert all(self.vision_patch_size.get(x) == patch_size for x in ["width", "height"]), \ "Patch size inconsistency between config.json and preprocessor_config.json" self.vision_resample = read(read_prep_config, int, ["resample"], no_default) self.vision_rescale_factor = read(read_prep_config, float, ["rescale_factor"], no_default) self.vision_size = read(read_prep_config, dict, ["size"], no_default) self.vision_num_channels = 3 - self.vision_spatial_merge_size = 1 + self.vision_spatial_merge_size = read(read_config, int, ["spatial_merge_size"], 1) self.vision_max_size = 16384 self.vision_window_size = None diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index d46672e..cf04815 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -115,6 +115,11 @@ class ExLlamaV2MLP(ExLlamaV2Module): else: self.gate_proj = None + if merge and ap.mlp_patch_merger: + self.patch_merger_proj = ExLlamaV2Linear(model, key + km["patch_merger"], in_features * merge**2, in_features, ap.mlp_bias) + self.submodules += [self.patch_merger_proj] + else: + self.patch_merger_proj = None def numel(self) -> int: @@ -158,6 +163,9 @@ class ExLlamaV2MLP(ExLlamaV2Module): if self.gate_proj is not None: self.gate_proj.load(device_context = device_context, output_map = down_map) self.up_proj.load(device_context = device_context, output_map = down_map) + if self.patch_merger_proj is not None: + self.patch_merger_proj.load() + if self.up_proj.is_quant(): assert self.gate_proj is None or self.gate_proj.is_quant() assert self.up_proj.is_quant(), "Partially quantized MLP layer" @@ -302,6 +310,8 @@ class ExLlamaV2MLP(ExLlamaV2Module): if self.gate_proj is not None: self.gate_proj.set_device_idx(idx) self.up_proj.set_device_idx(idx) self.down_proj.set_device_idx(idx) + if self.patch_merger_proj is not None: + self.patch_merger_proj.set_device_idx(idx) # @profile @@ -458,9 +468,18 @@ class ExLlamaV2MLP(ExLlamaV2Module): if self.pre_layernorm else hidden_states if self.merge: - bd = post_norm.shape[:-2] - l, d = post_norm.shape[-2:] - post_norm = post_norm.view(*bd, l // self.merge, d * self.merge) + if self.archparams.mlp_patch_merger: + bsz = hidden_states.shape[0] + assert bsz == 1 + (h, w), d = kwargs["patch_size"], hidden_states.shape[-1] + image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = F.unfold(image_grid, kernel_size = int(self.merge ** 0.5), stride = int(self.merge ** 0.5)) + grid = grid.view(bsz, d * self.merge, -1).transpose(1, 2) + post_norm = self.patch_merger_proj.forward(grid) + else: + bd = post_norm.shape[:-2] + l, d = post_norm.shape[-2:] + post_norm = post_norm.view(*bd, l // self.merge, d * self.merge) if self.gate_proj is not None: gate = self.gate_proj.forward(post_norm, loras = loras) diff --git a/exllamav2/vlm/processor/pixtral.py b/exllamav2/vlm/processor/pixtral.py index 24366a6..1e8b145 100644 --- a/exllamav2/vlm/processor/pixtral.py +++ b/exllamav2/vlm/processor/pixtral.py @@ -57,6 +57,9 @@ def postprocess( Insert [IMG_BREAK] and [IMG_END] tokens in image feature embeddings """ + features_x //= model.config.vision_spatial_merge_size + features_y //= model.config.vision_spatial_merge_size + assert embeddings.shape[0] == features_y * features_x, \ "Invalid shape for embeddings" diff --git a/exllamav2/vlm/vision_tower.py b/exllamav2/vlm/vision_tower.py index 99212ca..07f40dd 100644 --- a/exllamav2/vlm/vision_tower.py +++ b/exllamav2/vlm/vision_tower.py @@ -309,6 +309,7 @@ class ExLlamaV2VisionTower(ExLlamaV2): attn_params = attn_params, **kwargs | ({ "alt_rope_embedding": (cos, sin), + "patch_size": (p_height, p_width), } if cos is not None else {}) )