Split mode graph for Minimax-M2 (#1195)

* Split mode graph for Minimax-M2

* Cleanup

* Forgotten ffn_exp_probs_b
Kawrakow, 2026-01-29 07:27:06 +02:00 (committed by GitHub)
commit 68ed62447c, parent 68cd52e583
4 changed files with 238 additions and 60 deletions


@@ -779,6 +779,7 @@ static bool llama_kv_cache_init(
             n_mla++;
         }
         else {
+            int n_embd_head_v = hparams.n_embd_head_v;
             k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
             v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
             auto k_name = std::string{"cache_k_l"} + std::to_string(i);
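
For context, the non-split branch above allocates one K and one V cache tensor per layer, and their element counts follow directly from the shapes in the hunk. A small self-contained sketch of that arithmetic is below; the hyperparameter values are made-up examples, not MiniMax-M2's real ones.

// Sketch only: element counts of the per-layer K/V cache tensors allocated in
// the non-split branch above. All hyperparameter values are made-up examples.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd_head_k = 128;     // example K head dimension
    const int64_t n_head_kv     = 8;       // example number of KV heads
    const int64_t n_embd_v_gqa  = 128 * 8; // example: n_embd_head_v * n_head_kv
    const int64_t kv_size       = 4096;    // example number of cache slots

    // k: 2D tensor of shape n_embd_head_k x (n_head_kv * kv_size)
    const int64_t k_elems = n_embd_head_k * n_head_kv * kv_size;
    // v: 1D tensor of length n_embd_v_gqa * kv_size
    const int64_t v_elems = n_embd_v_gqa * kv_size;

    std::printf("K cache elements per layer: %lld\n", (long long) k_elems);
    std::printf("V cache elements per layer: %lld\n", (long long) v_elems);
    return 0;
}
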
@@ -793,6 +794,7 @@ static bool llama_kv_cache_init(
             auto K = model.layers[i].wk;
             auto V = model.layers[i].wv;
             if (K && V && K->extra && V->extra) {
+                bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? true : false;
                 auto extra_K = (const ggml_split_tensor_t *)K->extra;
                 auto extra_V = (const ggml_split_tensor_t *)V->extra;
                 auto & split_k_l = cache.split_k_l.emplace_back();
@@ -800,9 +802,14 @@ static bool llama_kv_cache_init(
                 split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
                 split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
                 for (int is = 0; is < extra_K->n_device; ++is) {
-                    auto split = extra_K->splits[is];
+                    auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
                     if (!split) continue;
-                    split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, split->ne[1]/n_embd_head_k * kv_size);
+                    int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k;
+                    if (use_V_for_K) {
+                        LLAMA_LOG_DEBUG("K_cache(%d, %d): using %d instead of %ld heads\n",
+                                i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
+                    }
+                    split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
                     auto split_name = k_name + '.' + std::to_string(is);
                     ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
                     mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
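
In the hunk above, the per-device K-cache split is sized by a head count taken from the V weight split when use_V_for_K is set, and from the K weight split otherwise. The following minimal sketch mirrors just that selection; split_dims and kv_heads_for_device are illustrative names, not part of the patch.

// Sketch only: mirrors the nhead_kv selection in the hunk above.
// split_dims and kv_heads_for_device are hypothetical names for illustration.
#include <cstdint>
#include <cstdio>

struct split_dims {
    int64_t k_ne1; // second dimension of this device's K weight split
    int64_t v_ne1; // second dimension of this device's V weight split
};

static int64_t kv_heads_for_device(const split_dims & s,
                                   int64_t n_embd_head_k,
                                   int64_t n_embd_head_v,
                                   bool use_V_for_K) {
    // With use_V_for_K, derive the head count from the V split (ne[1] / n_embd_head_v),
    // otherwise from the K split (ne[1] / n_embd_head_k).
    return use_V_for_K ? s.v_ne1 / n_embd_head_v
                       : s.k_ne1 / n_embd_head_k;
}

int main() {
    const split_dims s = { /*k_ne1=*/6144, /*v_ne1=*/1024 }; // made-up example sizes
    std::printf("heads from K split: %lld\n", (long long) kv_heads_for_device(s, 128, 128, false));
    std::printf("heads from V split: %lld\n", (long long) kv_heads_for_device(s, 128, 128, true));
    return 0;
}
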
@@ -1745,6 +1752,7 @@ static bool is_model_split_supported(const llama_model & model) {
         LLM_ARCH_HUNYUAN_MOE,
         LLM_ARCH_OPENAI_MOE,
         LLM_ARCH_ERNIE4_5_MOE,
+        LLM_ARCH_MINIMAX_M2,
     };
     auto it = k_supported.find(model.arch);
     return it != k_supported.end();
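
The architecture gate being extended here is a plain set lookup. The sketch below reduces that pattern to the entries visible in the hunk; arch_supports_split_graph is a hypothetical wrapper name, and the real k_supported set contains more architectures than shown.

// Sketch only: the split-mode architecture allow-list pattern, reduced to the
// entries visible in the hunk above. Not the actual llama.cpp definitions.
#include <set>

enum llm_arch {
    LLM_ARCH_HUNYUAN_MOE,
    LLM_ARCH_OPENAI_MOE,
    LLM_ARCH_ERNIE4_5_MOE,
    LLM_ARCH_MINIMAX_M2,
};

static bool arch_supports_split_graph(llm_arch arch) {
    static const std::set<llm_arch> k_supported = {
        LLM_ARCH_HUNYUAN_MOE,
        LLM_ARCH_OPENAI_MOE,
        LLM_ARCH_ERNIE4_5_MOE,
        LLM_ARCH_MINIMAX_M2, // the entry added by this commit
    };
    return k_supported.find(arch) != k_supported.end();
}

int main() { return arch_supports_split_graph(LLM_ARCH_MINIMAX_M2) ? 0 : 1; }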