diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index ef557209..8a5bd1b1 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1080,6 +1080,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + // If attention sinks are used, potentially re-scale if KQ_max is small. // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum // so it's being done unconditionally for every thread. @@ -1088,6 +1102,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + //const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col); const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); const float sink = sinks_f[jc % ncols2]; @@ -1126,20 +1141,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. - { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; -#pragma unroll - for (int col = 0; col < cols_per_thread; ++col) { -#pragma unroll - for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); - } - } - } - // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. @@ -1803,18 +1804,16 @@ static void launch_fattn_new_mma( to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); K_data = (char *) K_f16.ptr; - nb11 = K->ne[0]*sizeof(half); - nb12 = nb11*K->ne[1]; - nb13 = nb12*K->ne[2]; + auto bs = ggml_blck_size(K->type); + auto ts = ggml_type_size(K->type); - // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are - // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory. - //const size_t bs = ggml_blck_size(K->type); - //const size_t ts = ggml_type_size(K->type); + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; - //nb11 = nb11*bs*sizeof(half)/ts; - //nb12 = nb12*bs*sizeof(half)/ts; - //nb13 = nb13*bs*sizeof(half)/ts; + //nb11 = K->ne[0]*sizeof(half); + //nb12 = nb11*K->ne[1]; + //nb13 = nb12*K->ne[2]; } if (need_f16_V && V->type != GGML_TYPE_F16) { @@ -1831,17 +1830,16 @@ static void launch_fattn_new_mma( to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); V_data = (char *) V_f16.ptr; - nb21 = K->ne[0]*sizeof(half); - nb22 = nb21*V->ne[1]; - nb23 = nb22*V->ne[2]; + auto bs = ggml_blck_size(V->type); + auto ts = ggml_type_size(V->type); - // Original PR in llama.cpp. Same comment as above for the K cache. - //const size_t bs = ggml_blck_size(V->type); - //const size_t ts = ggml_type_size(V->type); + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; - //nb21 = nb21*bs*sizeof(half)/ts; - //nb22 = nb22*bs*sizeof(half)/ts; - //nb23 = nb23*bs*sizeof(half)/ts; + //nb21 = V->ne[0]*sizeof(half); + //nb22 = nb21*V->ne[1]; + //nb23 = nb22*V->ne[2]; } } @@ -2145,10 +2143,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens //} if (K->ne[0] == 192 && V->ne[0] == 128) { GGML_ASSERT(Q->ne[0] == 192); - GGML_ASSERT(gqa_ratio == 1); - //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst); + //GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1 + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst); // Reduce compile time - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst); + //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst); return; } if (K->ne[0] == 192 && V->ne[0] == 192) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 83c7cf40..ae86f7d7 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -93,7 +93,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0); + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); @@ -107,6 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // so no other implementation works. // if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) { + //printf("Using ggml_cuda_flash_attn_ext_mma_new\n"); ggml_cuda_flash_attn_ext_mma_new(ctx, dst); return; } @@ -170,7 +171,7 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0); + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d4406dfc..cc8fb624 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -69,6 +69,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINIMAX_M2, "minimax-m2" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -140,6 +141,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -150,6 +152,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index bfde21a1..efcbb577 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -68,6 +68,7 @@ enum llm_arch { LLM_ARCH_MINIMAX_M2, LLM_ARCH_SMOLLM3, LLM_ARCH_MISTRAL3, + LLM_ARCH_MIMO2, LLM_ARCH_UNKNOWN, }; @@ -133,6 +134,7 @@ enum llm_kv { LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -143,6 +145,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 7173dd42..937540a2 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1394,13 +1394,20 @@ static ggml_tensor * llm_build_kqv( auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) { + //if (n_swa > 0 && k->ne[1] > n_swa + q->ne[1]) { + // auto nton = n_swa + q->ne[1]; + // auto first = k->ne[1] - nton; + // k = ggml_view_3d(ctx, k, k->ne[0], nton, k->ne[2], k->nb[1], k->nb[2], k->nb[1]*first); + // v = ggml_view_3d(ctx, v, v->ne[0], nton, v->ne[2], v->nb[1], v->nb[2], v->nb[1]*first); + // kq_mask = ggml_view_3d(ctx, kq_mask, nton, kq_mask->ne[1], kq_mask->ne[2], kq_mask->nb[1], kq_mask->nb[2], kq_mask->nb[0]*first); + //} struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -1615,7 +1622,7 @@ std::tuple llm_build_context::llm_buil ggml_tensor * wk, ggml_tensor * bk, ggml_tensor * wv, ggml_tensor * bv, ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split) const { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); if (wqkv) { auto qkv = llm_build_lora_mm(lctx, ctx0, wqkv, cur); @@ -1627,8 +1634,8 @@ std::tuple llm_build_context::llm_buil qkv = ggml_add(ctx0, qkv, bqkv); cb(qkv, "qkv_b", il); } - auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd)); - auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]); + auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd)); + auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]); auto Vcur = ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(Qcur->ne[0]*Qcur->ne[1] + Kcur->ne[0]*Kcur->ne[1])); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -1669,8 +1676,8 @@ std::tuple llm_build_context::llm_buil } ggml_build_forward_expand(gf, qk); ggml_build_forward_expand(gf, Vcur); - auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd)); - auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]); + auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd)); + auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); if (q_norm) { @@ -1689,13 +1696,13 @@ std::tuple llm_build_context::llm_buil } auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il, add_graph_split); - auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, Q->ne[0]/n_embd_head, n_tokens); + auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head_k, Q->ne[0]/n_embd_head_k, n_tokens); if (q_norm) { Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il); cb(Qcur, "Qcur_normed", il); } - auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens); + auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head_k, K->ne[0]/n_embd_head_k, n_tokens); if (k_norm) { Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il); cb(Kcur, "Kcur_normed", il); @@ -8494,6 +8501,81 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() { return gf; } +ggml_cgraph * llm_build_context::build_mimo2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + //const int64_t n_embd_head = hparams.n_embd_head_v; + //GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + + for (int il = 0; il < n_layer; ++il) { + const bool is_sliding = model.hparams.swa_layers[il]; + auto KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, model.layers[il].attn_sinks, + nullptr, 1.0f/sqrtf(float(n_embd_head_k)), 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + auto ffn_inp = cur; + + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true); + cb(cur, "ffn_out", il); + } else { + cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, + model.layers[il].ffn_gate_inp, nullptr, + model.layers[il].ffn_up_exps, nullptr, + model.layers[il].ffn_gate_exps, nullptr, + model.layers[il].ffn_down_exps, nullptr, + model.layers[il].ffn_exp_probs_b, + nullptr, nullptr, // we don't have shared expert biases? + nullptr, nullptr, + nullptr, nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, false, 0.0f, + LLM_EXPERT_GATING_FUNC_SIGMOID, + LLM_FFN_SILU, cb, il, gf, true); + } + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} + ggml_cgraph * llm_build_context::build_openai_moe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -9317,6 +9399,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_mistral3(); } break; + case LLM_ARCH_MIMO2: + { + result = llm.build_mimo2(); + } break; default: GGML_ABORT("fatal error"); } @@ -9340,6 +9426,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in, ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale, int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input, bool is_norm) { + + float freq_base_l = n_swa > 0 ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train; + if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn && model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) { if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) { @@ -9414,9 +9504,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens rope_factors = extra->splits[id]; } if (do_rope) { - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); } cb(Qcur, "Qcur", il_cb); @@ -9550,9 +9640,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il); if (do_rope) { - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); } cb(Qcur, "Qcur", il); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 498c3a5d..ac5577d6 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -276,6 +276,8 @@ struct llm_build_context { ggml_cgraph * build_smollm3(); + ggml_cgraph * build_mimo2(); + // static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0, ggml_tensor * w, ggml_tensor * cur); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 81afc62b..4cef9236 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -1072,6 +1072,23 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MIMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + //TODO + //hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; // which is the same as OpenAI + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + switch (hparams.n_layer) { + case 48: model.type = e_model::MODEL_310B_A15B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + } break; default: (void)0; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9ef3eefc..984db1ed 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -113,7 +113,7 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; - + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -122,6 +122,8 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + std::array swa_layers; + bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; if (this->n_vocab != other.n_vocab) return true; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index cc50a647..f55b1d74 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -133,6 +133,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_smollm3_tensors(const LLM_TN & tn); + bool create_mimo2_tensors(const LLM_TN & tn); + llama_model_loader & ml; llama_model & model; @@ -1194,6 +1196,49 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + + create_embd_output(tn, n_embd, n_vocab, true, false); //true); + + for (int i = 0; i < n_layer; ++i) { + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + //printf("Layer %2d: n_head = %u, n_embd_head_k = %d, n_embd_head_v = %d, n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", i, n_head, (int)n_embd_head_k, (int)n_embd_head_v, n_embd_k_gqa, n_embd_v_gqa); + + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sinks = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); + + auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer; + layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + // non-MoE branch + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + return use_mmap_buffer; +} + bool create_tensors_helper::create_phi2_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -2947,6 +2992,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_minimaxm2_tensors(tn); break; case LLM_ARCH_SMOLLM3: use_mmap_buffer = create_smollm3_tensors(tn); break; + case LLM_ARCH_MIMO2: + use_mmap_buffer = create_mimo2_tensors(tn); break; default: throw std::runtime_error("unknown architecture"); } @@ -2955,7 +3002,6 @@ bool create_tensors_helper::create_tensors() { printf("================================ max_gpu = %d\n", model.max_gpu); std::vector mem_used(model.splits.size(), 0); const auto & hparams = model.hparams; - int gqa_ratio = hparams.n_head() / hparams.n_head_kv(); auto cur_splits = model.splits; int adjust_step = std::max(1, int(n_layer / (2*model.splits.size()))); if (model.max_gpu > 1 && model.max_gpu < int(cur_splits.size())) { @@ -2976,6 +3022,7 @@ bool create_tensors_helper::create_tensors() { } } for (int il = 0; il < n_layer; ++il) { + int gqa_ratio = hparams.n_head(il) / hparams.n_head_kv(il); if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) { LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il); continue; @@ -2996,19 +3043,26 @@ bool create_tensors_helper::create_tensors() { if (layer.attn_norm) { auto split = create_split(ggml_nrows(layer.attn_norm), -1, cur_splits, mem_used); prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used); + if (layer.attn_sinks) { + prepare_split_tensors(-1, ctx_split, layer.attn_sinks, layer.split_attn_sinks, split, mem_used); + } } if (layer.rope_freqs) { auto split = create_split(ggml_nrows(layer.rope_freqs), -1, cur_splits, mem_used); prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used); } if (layer.wo && layer.wq && layer.wk && layer.wv) { - int attn_granularity = hparams.n_embd_head_k * gqa_ratio; + // TODO: fix this logic. It only works whe K and V head size is the same + //printf("Layer %d: q = %ld x %ld, k = %ld x %ld, v = %ld x %ld, qo = %ld x %ld\n", il, layer.wq->ne[0], layer.wq->ne[1], + // layer.wk->ne[0], layer.wk->ne[1], layer.wv->ne[0], layer.wv->ne[1], layer.wo->ne[0], layer.wo->ne[1]); + int attn_granularity = hparams.n_embd_head_v * gqa_ratio; if (ggml_is_quantized(layer.wo->type)) { auto tt = ggml_internal_get_type_traits(layer.wo->type); if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size; } - GGML_ASSERT(attn_granularity % hparams.n_embd_head_k == 0); + GGML_ASSERT(attn_granularity % hparams.n_embd_head_v == 0); auto split = create_split(layer.wo->ne[0], attn_granularity, cur_splits, mem_used); + //printf("Split:"); for (auto s : split) printf(" %d", s); printf("\n"); prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used); prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used); if (layer.bo) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 94c30ee7..be4b9cef 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1294,6 +1294,29 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_MIMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1538,6 +1561,7 @@ const char * llama_model_type_name(e_model type) { case MODEL_106B_A12B: return "106B.A12B"; case MODEL_230B_A10B: return "230B.A10B"; case MODEL_235B_A22B: return "235B.A22B"; + case MODEL_310B_A15B: return "310B.A15B"; case MODEL_300B_A47B: return "300B.A47B"; case MODEL_355B_A32B: return "355B.A32B"; case MODEL_E2B: return "E2B"; diff --git a/src/llama-model.h b/src/llama-model.h index a252ab5e..86aaebd8 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -112,6 +112,7 @@ enum e_model { MODEL_106B_A12B, MODEL_230B_A10B, // Minimax M2 MODEL_235B_A22B, + MODEL_310B_A15B, MODEL_300B_A47B, // Ernie MoE big MODEL_355B_A32B, MODEL_E2B, @@ -184,6 +185,7 @@ struct llama_layer { struct ggml_tensor * bkv = nullptr; llama_split_tensor split_attn_norm; + llama_split_tensor split_attn_sinks; llama_split_tensor split_wq; llama_split_tensor split_wk; llama_split_tensor split_wv; diff --git a/src/llama.cpp b/src/llama.cpp index 449d41a7..33def225 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4905,6 +4905,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_MINIMAX_M2: + case LLM_ARCH_MIMO2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: