Graph parallel for Mimo-V2-Flash (#1105)

* WIP

* Cleanup

* Set max_gpu to 2 for Mimo2

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2026-01-05 09:58:54 +02:00
committed by GitHub
parent 385fc14110
commit 419a397ce0
5 changed files with 45 additions and 40 deletions

View File

@@ -3043,48 +3043,51 @@ bool create_tensors_helper::create_tensors() {
if (layer.attn_norm) {
auto split = create_split(ggml_nrows(layer.attn_norm), -1, cur_splits, mem_used);
prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used);
if (layer.attn_sinks) {
prepare_split_tensors(-1, ctx_split, layer.attn_sinks, layer.split_attn_sinks, split, mem_used);
}
}
if (layer.rope_freqs) {
auto split = create_split(ggml_nrows(layer.rope_freqs), -1, cur_splits, mem_used);
prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used);
}
if (layer.wo && layer.wq && layer.wk && layer.wv) {
// TODO: fix this logic. It only works whe K and V head size is the same
//printf("Layer %d: q = %ld x %ld, k = %ld x %ld, v = %ld x %ld, qo = %ld x %ld\n", il, layer.wq->ne[0], layer.wq->ne[1],
// layer.wk->ne[0], layer.wk->ne[1], layer.wv->ne[0], layer.wv->ne[1], layer.wo->ne[0], layer.wo->ne[1]);
int attn_granularity = hparams.n_embd_head_v * gqa_ratio;
auto granularity_kq = hparams.n_embd_head_k * gqa_ratio;
auto granularity_vo = hparams.n_embd_head_v * gqa_ratio;
if (ggml_is_quantized(layer.wo->type)) {
auto tt = ggml_internal_get_type_traits(layer.wo->type);
if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size;
if (tt.blck_size > granularity_vo) granularity_vo = tt.blck_size;
GGML_ASSERT(granularity_vo % hparams.n_embd_head_v == 0);
}
GGML_ASSERT(attn_granularity % hparams.n_embd_head_v == 0);
auto split = create_split(layer.wo->ne[0], attn_granularity, cur_splits, mem_used);
//printf("Split:"); for (auto s : split) printf(" %d", s); printf("\n");
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used);
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used);
auto split_vo = create_split(layer.wo->ne[0], granularity_vo, cur_splits, mem_used);
auto split_kq = create_split(layer.wq->ne[1], granularity_kq, cur_splits, mem_used);
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split_vo, mem_used);
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split_kq, mem_used);
if (layer.bo) {
prepare_split_tensors(-1, ctx_split, layer.bo, layer.split_bo, split, mem_used);
prepare_split_tensors(-1, ctx_split, layer.bo, layer.split_bo, split_vo, mem_used);
}
if (layer.bq) {
prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split, mem_used);
prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split_kq, mem_used);
}
if (layer.attn_q_norm) {
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split, mem_used);
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split_kq, mem_used);
}
for (auto & s : split) s /= gqa_ratio;
prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split, mem_used);
prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split, mem_used);
if (layer.attn_sinks) {
auto split_sinks = split_kq;
for (auto & s : split_sinks) {
s /= hparams.n_embd_head_k;
}
prepare_split_tensors(0, ctx_split, layer.attn_sinks, layer.split_sinks, split_sinks, mem_used);
}
for (auto & s : split_kq) s /= gqa_ratio;
for (auto & s : split_vo) s /= gqa_ratio;
prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split_kq, mem_used);
prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split_vo, mem_used);
if (layer.bk) {
prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split, mem_used);
prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split_kq, mem_used);
}
if (layer.bv) {
prepare_split_tensors(0, ctx_split, layer.bv, layer.split_bv, split, mem_used);
prepare_split_tensors(0, ctx_split, layer.bv, layer.split_bv, split_vo, mem_used);
}
if (layer.attn_k_norm) {
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split, mem_used);
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split_kq, mem_used);
}
}