Mimo-V2-Flash support (#1096)

* Mimo-2 support

* Fix bug for head sizes not being the same

It still does not solve the Mimo-2 quantized cache issue.

* Fix quantized cache

* Minor

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2026-01-05 08:00:01 +02:00
committed by GitHub
parent 1401326916
commit 8a6622eb4f
12 changed files with 251 additions and 54 deletions

View File

@@ -1080,6 +1080,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8/4 threads each, does not need full reduce.
{
constexpr int offset_first = ntiles == 1 ? 16 : 2;
constexpr int offset_last = ntiles == 1 ? 4 : 1;
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
}
}
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
// so it's being done unconditionally for every thread.
@@ -1088,6 +1102,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
//const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
@@ -1126,20 +1141,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8/4 threads each, does not need full reduce.
{
constexpr int offset_first = ntiles == 1 ? 16 : 2;
constexpr int offset_last = ntiles == 1 ? 4 : 1;
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
}
}
}
// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1803,18 +1804,16 @@ static void launch_fattn_new_mma(
to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
K_data = (char *) K_f16.ptr;
nb11 = K->ne[0]*sizeof(half);
nb12 = nb11*K->ne[1];
nb13 = nb12*K->ne[2];
auto bs = ggml_blck_size(K->type);
auto ts = ggml_type_size(K->type);
// Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
// gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
//const size_t bs = ggml_blck_size(K->type);
//const size_t ts = ggml_type_size(K->type);
nb11 = nb11*bs*sizeof(half)/ts;
nb12 = nb12*bs*sizeof(half)/ts;
nb13 = nb13*bs*sizeof(half)/ts;
//nb11 = nb11*bs*sizeof(half)/ts;
//nb12 = nb12*bs*sizeof(half)/ts;
//nb13 = nb13*bs*sizeof(half)/ts;
//nb11 = K->ne[0]*sizeof(half);
//nb12 = nb11*K->ne[1];
//nb13 = nb12*K->ne[2];
}
if (need_f16_V && V->type != GGML_TYPE_F16) {
@@ -1831,17 +1830,16 @@ static void launch_fattn_new_mma(
to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
V_data = (char *) V_f16.ptr;
nb21 = K->ne[0]*sizeof(half);
nb22 = nb21*V->ne[1];
nb23 = nb22*V->ne[2];
auto bs = ggml_blck_size(V->type);
auto ts = ggml_type_size(V->type);
// Original PR in llama.cpp. Same comment as above for the K cache.
//const size_t bs = ggml_blck_size(V->type);
//const size_t ts = ggml_type_size(V->type);
nb21 = nb21*bs*sizeof(half)/ts;
nb22 = nb22*bs*sizeof(half)/ts;
nb23 = nb23*bs*sizeof(half)/ts;
//nb21 = nb21*bs*sizeof(half)/ts;
//nb22 = nb22*bs*sizeof(half)/ts;
//nb23 = nb23*bs*sizeof(half)/ts;
//nb21 = V->ne[0]*sizeof(half);
//nb22 = nb21*V->ne[1];
//nb23 = nb22*V->ne[2];
}
}
@@ -2145,10 +2143,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
//}
if (K->ne[0] == 192 && V->ne[0] == 128) {
GGML_ASSERT(Q->ne[0] == 192);
GGML_ASSERT(gqa_ratio == 1);
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
//GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
// Reduce compile time
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
return;
}
if (K->ne[0] == 192 && V->ne[0] == 192) {

View File

@@ -93,7 +93,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -107,6 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// so no other implementation works.
//
if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
//printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
return;
}
@@ -170,7 +171,7 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);

View File

@@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -140,6 +141,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
@@ -150,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },

View File

@@ -68,6 +68,7 @@ enum llm_arch {
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_SMOLLM3,
LLM_ARCH_MISTRAL3,
LLM_ARCH_MIMO2,
LLM_ARCH_UNKNOWN,
};
@@ -133,6 +134,7 @@ enum llm_kv {
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
@@ -143,6 +145,7 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,

View File

@@ -1394,13 +1394,20 @@ static ggml_tensor * llm_build_kqv(
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
//if (n_swa > 0 && k->ne[1] > n_swa + q->ne[1]) {
// auto nton = n_swa + q->ne[1];
// auto first = k->ne[1] - nton;
// k = ggml_view_3d(ctx, k, k->ne[0], nton, k->ne[2], k->nb[1], k->nb[2], k->nb[1]*first);
// v = ggml_view_3d(ctx, v, v->ne[0], nton, v->ne[2], v->nb[1], v->nb[2], v->nb[1]*first);
// kq_mask = ggml_view_3d(ctx, kq_mask, nton, kq_mask->ne[1], kq_mask->ne[2], kq_mask->nb[1], kq_mask->nb[2], kq_mask->nb[0]*first);
//}
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -1615,7 +1622,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split) const {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
if (wqkv) {
auto qkv = llm_build_lora_mm(lctx, ctx0, wqkv, cur);
@@ -1627,8 +1634,8 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
qkv = ggml_add(ctx0, qkv, bqkv);
cb(qkv, "qkv_b", il);
}
auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
auto Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qkv->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
auto Vcur = ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(Qcur->ne[0]*Qcur->ne[1] + Kcur->ne[0]*Kcur->ne[1]));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
@@ -1669,8 +1676,8 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
}
ggml_build_forward_expand(gf, qk);
ggml_build_forward_expand(gf, Vcur);
auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
auto Qcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 0*sizeof(float)*(n_embd));
auto Kcur = ggml_view_3d(ctx0, qk, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k*sizeof(float), qk->nb[1], 1*sizeof(float)*Qcur->ne[0]*Qcur->ne[1]);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
if (q_norm) {
@@ -1689,13 +1696,13 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
}
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il, add_graph_split);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, Q->ne[0]/n_embd_head, n_tokens);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head_k, Q->ne[0]/n_embd_head_k, n_tokens);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
}
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens);
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head_k, K->ne[0]/n_embd_head_k, n_tokens);
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
@@ -8494,6 +8501,81 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() {
return gf;
}
ggml_cgraph * llm_build_context::build_mimo2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
//const int64_t n_embd_head = hparams.n_embd_head_v;
//GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
//GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
for (int il = 0; il < n_layer; ++il) {
const bool is_sliding = model.hparams.swa_layers[il];
auto KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr, KQ_mask_l, model.layers[il].attn_sinks,
nullptr, 1.0f/sqrtf(float(n_embd_head_k)), 0.0f, is_sliding ? hparams.n_swa : 0, il, true, false, true);
if (il == n_layer - 1) {
// skip computing output for unused tokens
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
auto ffn_inp = cur;
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(cur, "ffn_out", il);
} else {
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp,
model.layers[il].ffn_gate_inp, nullptr,
model.layers[il].ffn_up_exps, nullptr,
model.layers[il].ffn_gate_exps, nullptr,
model.layers[il].ffn_down_exps, nullptr,
model.layers[il].ffn_exp_probs_b,
nullptr, nullptr, // we don't have shared expert biases?
nullptr, nullptr,
nullptr, nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true, false, 0.0f,
LLM_EXPERT_GATING_FUNC_SIGMOID,
LLM_FFN_SILU, cb, il, gf, true);
}
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_output(lctx, ctx0, cur, model.output, model.output_norm, cb);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::build_openai_moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -9317,6 +9399,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_mistral3();
} break;
case LLM_ARCH_MIMO2:
{
result = llm.build_mimo2();
} break;
default:
GGML_ABORT("fatal error");
}
@@ -9340,6 +9426,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input, bool is_norm) {
float freq_base_l = n_swa > 0 ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
@@ -9414,9 +9504,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
rope_factors = extra->splits[id];
}
if (do_rope) {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il_cb);
@@ -9550,9 +9640,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
if (do_rope) {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);

View File

@@ -276,6 +276,8 @@ struct llm_build_context {
ggml_cgraph * build_smollm3();
ggml_cgraph * build_mimo2();
//
static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0,
ggml_tensor * w, ggml_tensor * cur);

View File

@@ -1072,6 +1072,23 @@ void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MIMO2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
//TODO
//hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; // which is the same as OpenAI
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
switch (hparams.n_layer) {
case 48: model.type = e_model::MODEL_310B_A15B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

View File

@@ -113,7 +113,7 @@ struct llama_hparams {
// qwen3vl deepstack
uint32_t n_deepstack_layers = 0;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = -1;
@@ -122,6 +122,8 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
if (this->n_vocab != other.n_vocab) return true;

View File

@@ -133,6 +133,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
bool create_smollm3_tensors(const LLM_TN & tn);
bool create_mimo2_tensors(const LLM_TN & tn);
llama_model_loader & ml;
llama_model & model;
@@ -1194,6 +1196,49 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
return use_mmap_buffer;
}
bool create_tensors_helper::create_mimo2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
create_embd_output(tn, n_embd, n_vocab, true, false); //true);
for (int i = 0; i < n_layer; ++i) {
uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
uint32_t n_head = hparams.n_head(i);
//printf("Layer %2d: n_head = %u, n_embd_head_k = %d, n_embd_head_v = %d, n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", i, n_head, (int)n_embd_head_k, (int)n_embd_head_v, n_embd_k_gqa, n_embd_v_gqa);
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_sinks = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
// non-MoE branch
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
// MoE branch
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
return use_mmap_buffer;
}
bool create_tensors_helper::create_phi2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
@@ -2947,6 +2992,8 @@ bool create_tensors_helper::create_tensors() {
use_mmap_buffer = create_minimaxm2_tensors(tn); break;
case LLM_ARCH_SMOLLM3:
use_mmap_buffer = create_smollm3_tensors(tn); break;
case LLM_ARCH_MIMO2:
use_mmap_buffer = create_mimo2_tensors(tn); break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -2955,7 +3002,6 @@ bool create_tensors_helper::create_tensors() {
printf("================================ max_gpu = %d\n", model.max_gpu);
std::vector<size_t> mem_used(model.splits.size(), 0);
const auto & hparams = model.hparams;
int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
auto cur_splits = model.splits;
int adjust_step = std::max(1, int(n_layer / (2*model.splits.size())));
if (model.max_gpu > 1 && model.max_gpu < int(cur_splits.size())) {
@@ -2976,6 +3022,7 @@ bool create_tensors_helper::create_tensors() {
}
}
for (int il = 0; il < n_layer; ++il) {
int gqa_ratio = hparams.n_head(il) / hparams.n_head_kv(il);
if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
continue;
@@ -2996,19 +3043,26 @@ bool create_tensors_helper::create_tensors() {
if (layer.attn_norm) {
auto split = create_split(ggml_nrows(layer.attn_norm), -1, cur_splits, mem_used);
prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used);
if (layer.attn_sinks) {
prepare_split_tensors(-1, ctx_split, layer.attn_sinks, layer.split_attn_sinks, split, mem_used);
}
}
if (layer.rope_freqs) {
auto split = create_split(ggml_nrows(layer.rope_freqs), -1, cur_splits, mem_used);
prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used);
}
if (layer.wo && layer.wq && layer.wk && layer.wv) {
int attn_granularity = hparams.n_embd_head_k * gqa_ratio;
// TODO: fix this logic. It only works whe K and V head size is the same
//printf("Layer %d: q = %ld x %ld, k = %ld x %ld, v = %ld x %ld, qo = %ld x %ld\n", il, layer.wq->ne[0], layer.wq->ne[1],
// layer.wk->ne[0], layer.wk->ne[1], layer.wv->ne[0], layer.wv->ne[1], layer.wo->ne[0], layer.wo->ne[1]);
int attn_granularity = hparams.n_embd_head_v * gqa_ratio;
if (ggml_is_quantized(layer.wo->type)) {
auto tt = ggml_internal_get_type_traits(layer.wo->type);
if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size;
}
GGML_ASSERT(attn_granularity % hparams.n_embd_head_k == 0);
GGML_ASSERT(attn_granularity % hparams.n_embd_head_v == 0);
auto split = create_split(layer.wo->ne[0], attn_granularity, cur_splits, mem_used);
//printf("Split:"); for (auto s : split) printf(" %d", s); printf("\n");
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used);
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used);
if (layer.bo) {

View File

@@ -1294,6 +1294,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_MIMO2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -1538,6 +1561,7 @@ const char * llama_model_type_name(e_model type) {
case MODEL_106B_A12B: return "106B.A12B";
case MODEL_230B_A10B: return "230B.A10B";
case MODEL_235B_A22B: return "235B.A22B";
case MODEL_310B_A15B: return "310B.A15B";
case MODEL_300B_A47B: return "300B.A47B";
case MODEL_355B_A32B: return "355B.A32B";
case MODEL_E2B: return "E2B";

View File

@@ -112,6 +112,7 @@ enum e_model {
MODEL_106B_A12B,
MODEL_230B_A10B, // Minimax M2
MODEL_235B_A22B,
MODEL_310B_A15B,
MODEL_300B_A47B, // Ernie MoE big
MODEL_355B_A32B,
MODEL_E2B,
@@ -184,6 +185,7 @@ struct llama_layer {
struct ggml_tensor * bkv = nullptr;
llama_split_tensor split_attn_norm;
llama_split_tensor split_attn_sinks;
llama_split_tensor split_wq;
llama_split_tensor split_wk;
llama_split_tensor split_wv;

View File

@@ -4905,6 +4905,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_MINIMAX_M2:
case LLM_ARCH_MIMO2:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL: