From f2a094d92d8b528fda0c67ddbb6c7af05f9109a1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 25 Sep 2025 20:09:46 +0300 Subject: [PATCH] Add mtmd: add Qwen2-VL --- include/llama.h | 11 ++-- src/llama-arch.h | 3 + src/llama.cpp | 150 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+), 6 deletions(-) diff --git a/include/llama.h b/include/llama.h index 8e6bfad2..4f9fc9c8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -117,13 +117,12 @@ extern "C" { // LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 //}; - // note: these values should be synchronized with ggml_rope - // TODO: maybe move this enum to ggml.h (ggml_rope_type) enum llama_rope_type { - LLAMA_ROPE_TYPE_NONE = -1, - LLAMA_ROPE_TYPE_NORM = 0, - LLAMA_ROPE_TYPE_NEOX = 2, - LLAMA_ROPE_TYPE_GLM = 4, + LLAMA_ROPE_TYPE_NONE = -1, + LLAMA_ROPE_TYPE_NORM = 0, + LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, + LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, + LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, }; enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file diff --git a/src/llama-arch.h b/src/llama-arch.h index ab68dc5d..d5463ba0 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -23,6 +23,7 @@ enum llm_arch { LLM_ARCH_QWEN, LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, LLM_ARCH_PHI2, @@ -126,7 +127,9 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, diff --git a/src/llama.cpp b/src/llama.cpp index 54b1048f..80561dd3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -203,6 +203,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -314,6 +315,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, @@ -755,6 +757,23 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN2VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_QWEN3, { @@ -1959,6 +1978,8 @@ struct llama_hparams { float yarn_beta_fast = 32.0f; float yarn_beta_slow = 1.0f; + std::array rope_sections; + // for State Space Models uint32_t ssm_d_conv = 0; uint32_t ssm_d_inner = 0; @@ -3837,6 +3858,11 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2VL: + { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + } + // fall through case LLM_ARCH_QWEN2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3860,6 +3886,7 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5741,6 +5768,7 @@ static bool llm_load_tensors( } } break; case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5748,6 +5776,8 @@ static bool llm_load_tensors( { model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_b = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed if (model.output == NULL) { model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); @@ -10825,6 +10855,119 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_qwen2vl() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + //auto * inp_attn = build_attn_inp_kv(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq, + model.layers[il].wk, model.layers[il].bk, + model.layers[il].wv, model.layers[il].bv, 0.f, il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + //cur = ggml_get_rows(ctx0, cur, inp_out_ids); + //inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + + } + struct ggml_cgraph * build_qwen2moe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -15899,6 +16042,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen2(); } break; + case LLM_ARCH_QWEN2VL: + { + result = llm.build_qwen2vl(); + } break; case LLM_ARCH_QWEN2MOE: { result = llm.build_qwen2moe(); @@ -19873,6 +20020,9 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OPENAI_MOE: return LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_QWEN2VL: + return LLAMA_ROPE_TYPE_MROPE; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture");