Graph parallel for Mimo-V2-Flash (#1105)

* WIP

* Cleanup

* Set max_gpu to 2 for Mimo2

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2026-01-05 09:58:54 +02:00
committed by GitHub
parent 385fc14110
commit 419a397ce0
5 changed files with 45 additions and 40 deletions

View File

@@ -1394,13 +1394,6 @@ static ggml_tensor * llm_build_kqv(
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
//if (n_swa > 0 && k->ne[1] > n_swa + q->ne[1]) {
// auto nton = n_swa + q->ne[1];
// auto first = k->ne[1] - nton;
// k = ggml_view_3d(ctx, k, k->ne[0], nton, k->ne[2], k->nb[1], k->nb[2], k->nb[1]*first);
// v = ggml_view_3d(ctx, v, v->ne[0], nton, v->ne[2], v->nb[1], v->nb[2], v->nb[1]*first);
// kq_mask = ggml_view_3d(ctx, kq_mask, nton, kq_mask->ne[1], kq_mask->ne[2], kq_mask->nb[1], kq_mask->nb[2], kq_mask->nb[0]*first);
//}
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -9433,7 +9426,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
ggml_split_tensor_t * attn_norm = the_attn_norm ? (ggml_split_tensor_t *)the_attn_norm->extra : nullptr;
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
@@ -9481,11 +9473,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
}
}
//if (attn_norm) {
// auto split_norm = attn_norm->splits[id];
// cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, is_norm ? LLM_NORM : LLM_NORM_RMS, cb, il);
// cb(cur, "attn_norm", il_cb);
//}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
}
@@ -9583,15 +9570,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "flash_attn", il_cb);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (model.layers[il].attn_sinks && model.layers[il].attn_sinks->extra) {
auto split = (ggml_split_tensor_t *)model.layers[il].attn_sinks->extra;
GGML_ASSERT(split->n_device == wq->n_device);
GGML_ASSERT(split->splits[id]);
ggml_flash_attn_ext_add_sinks(cur, split->splits[id]);
} else {
ggml_flash_attn_ext_add_sinks(cur, sinks);
}
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
model.arch == LLM_ARCH_GLM4_MOE) {
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
model.arch == LLM_ARCH_GLM4_MOE) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}