mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 22:24:11 +00:00
Split mode graph for Minimax-M2
This commit is contained in:
@@ -3570,6 +3570,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
return false;
|
||||
}
|
||||
|
||||
#if 0
|
||||
if (auto err = cudaStreamSynchronize(ctx.stream()); err != cudaSuccess) {
|
||||
GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
|
||||
CUDA_CHECK(err);
|
||||
}
|
||||
#endif
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
|
||||
|
||||
@@ -8612,7 +8612,157 @@ ggml_cgraph* llm_build_context::build_minimaxm2() {
|
||||
cur = inpL;
|
||||
|
||||
// self_attention
|
||||
{
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
// Unfortunately we cannot use build_std_attention because Q and K get normed before being RoPE'd,
|
||||
// but the RMS norm is applied on the whole row, and not per head as it is normally done.
|
||||
// Hence, we need to keep a copy of wq and wk on each device, do the whole matrix multiplications
|
||||
// on each device, apply the norm, and only then take from the result the self attention portion
|
||||
// being processed on the given device. If we would split wq and wk, we would need to reassemble
|
||||
// the whole Q and K via reduce-concat to apply the RMS norm, and that would kill performance.
|
||||
// Alternatively, we would need to add an extra reduce op, which computes the squared sum on each device,
|
||||
// than does a reduce-add operation to compute the total sum (per row) of Q and K, and then
|
||||
// it performs RMS norm using that. This would be possibly better, but let's leave it for another day.
|
||||
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
|
||||
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
|
||||
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
|
||||
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
|
||||
GGML_ASSERT(wq && wk && wv && wo);
|
||||
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
|
||||
auto q_norm = (ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra;
|
||||
auto k_norm = (ggml_split_tensor_t *)model.layers[il].attn_k_norm->extra;
|
||||
auto attn_norm = (ggml_split_tensor_t *)model.layers[il].attn_norm->extra;
|
||||
GGML_ASSERT(attn_norm && q_norm && k_norm);
|
||||
GGML_ASSERT(wq->n_device == q_norm->n_device && wq->n_device == k_norm->n_device && wq->n_device == attn_norm->n_device);
|
||||
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
|
||||
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
|
||||
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
|
||||
int head_count = 0;
|
||||
int head_count_kv = 0;
|
||||
int n_device = wq->n_device;
|
||||
std::vector<ggml_tensor *> attn(n_device, nullptr);
|
||||
bool input_added = false;
|
||||
for (int id = 0; id < n_device; ++id) {
|
||||
if (!wq->splits[id]) continue;
|
||||
int il_id = 1000*il + id;
|
||||
auto input = get_input_tensor_sm_graph(ctx0, inpL, id);
|
||||
cur = llm_build_norm(ctx0, input, hparams, attn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il_id);
|
||||
|
||||
auto Qcur = llm_build_lora_mm(lctx, ctx0, wq->splits[id], cur);
|
||||
cb(Qcur, "Qcur", il_id);
|
||||
|
||||
auto Kcur = llm_build_lora_mm(lctx, ctx0, wk->splits[id], cur);
|
||||
cb(Kcur, "Kcur", il_id);
|
||||
|
||||
auto Vcur = llm_build_lora_mm(lctx, ctx0, wv->splits[id], cur);
|
||||
cb(Vcur, "Vcur", il_id);
|
||||
|
||||
// Do this here so Q, K, V matrix multiplications may be fused
|
||||
ggml_build_forward_expand(gf, Qcur);
|
||||
ggml_build_forward_expand(gf, Kcur);
|
||||
ggml_build_forward_expand(gf, Vcur);
|
||||
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il_id);
|
||||
cb(Qcur, "Qcur_normed", il_id);
|
||||
|
||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il_id);
|
||||
cb(Kcur, "Kcur_normed", il_id);
|
||||
|
||||
// reshape for multi-head
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
int gqa_ratio = n_head / n_head_kv;
|
||||
int nhead_kv_id = Vcur->ne[0] / n_embd_head_v;
|
||||
int nhead_id = nhead_kv_id * gqa_ratio;
|
||||
GGML_ASSERT(nhead_kv_id > 0 && nhead_kv_id <= n_head_kv);
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head_k, nhead_id, n_tokens, Qcur->nb[1], Qcur->nb[2], head_count*Qcur->nb[1]);
|
||||
Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head_k, nhead_kv_id, n_tokens, Kcur->nb[1], Kcur->nb[2], head_count_kv*Kcur->nb[1]);
|
||||
head_count += nhead_id;
|
||||
head_count_kv += nhead_kv_id;
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_roped", il_id);
|
||||
cb(Kcur, "Kcur_roped", il_id);
|
||||
|
||||
if (cparams.k_cache_hadamard) {
|
||||
Qcur = ggml_hadamard(ctx0, Qcur, hparams.n_embd_head_k);
|
||||
Kcur = ggml_hadamard(ctx0, Kcur, hparams.n_embd_head_k);
|
||||
cb(Qcur, "Qcur_hadamard", il_id);
|
||||
cb(Kcur, "Kcur_hadamard", il_id);
|
||||
}
|
||||
ggml_build_forward_expand(gf, Qcur);
|
||||
ggml_build_forward_expand(gf, Kcur);
|
||||
|
||||
// Store K, V in KV cache
|
||||
auto idx = 2*wq->n_device*il + 2*id;
|
||||
GGML_ASSERT(idx+1 < (int)lctx.cache_copies.size());
|
||||
auto k_row_size = ggml_row_size(kl->splits[id]->type, n_embd_head_k);
|
||||
auto k_cache_view = ggml_view_2d(ctx0, kl->splits[id], n_embd_head_k, n_tokens*nhead_kv_id,
|
||||
k_row_size, k_row_size*nhead_kv_id*kv_head);
|
||||
lctx.cache_copies[idx+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
|
||||
lctx.cache_copies[idx+0].step = k_row_size*nhead_kv_id;
|
||||
|
||||
auto v_cache_view = ggml_view_1d(ctx0, vl->splits[id], n_tokens*wv->splits[id]->ne[1],
|
||||
kv_head*ggml_row_size(vl->splits[id]->type, wv->splits[id]->ne[1]));
|
||||
lctx.cache_copies[idx+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
|
||||
lctx.cache_copies[idx+1].step = ggml_row_size(vl->splits[id]->type, wv->splits[id]->ne[1]);
|
||||
|
||||
ggml_build_forward_expand(gf, lctx.cache_copies[idx+0].cpy);
|
||||
ggml_build_forward_expand(gf, lctx.cache_copies[idx+1].cpy);
|
||||
|
||||
auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
cb(q, "q", il_id);
|
||||
|
||||
auto k = ggml_view_3d(ctx0, kl->splits[id], n_embd_head_k, n_kv, nhead_kv_id,
|
||||
ggml_row_size(kl->splits[id]->type, n_embd_head_k)*nhead_kv_id,
|
||||
ggml_row_size(kl->splits[id]->type, n_embd_head_k), 0);
|
||||
cb(k, "k", il_id);
|
||||
|
||||
auto v = ggml_view_3d(ctx0, vl->splits[id], n_embd_head_v, n_kv, nhead_kv_id,
|
||||
ggml_row_size( vl->splits[id]->type, wv->splits[id]->ne[1]),
|
||||
ggml_row_size( vl->splits[id]->type, n_embd_head_v), 0);
|
||||
cb(v, "v", il_id);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, 1.0f / sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias, 0.0f);
|
||||
cb(cur, "fa", il_id);
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, wo->splits[id]->ne[0], n_tokens);
|
||||
cb(cur, "fa_reshaped", il_id);
|
||||
|
||||
if (il == n_layer - 1 && n_tokens > 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
cb(cur, "fa_get_rows", il_id);
|
||||
if (!input_added) {
|
||||
input = ggml_get_rows(ctx0, input, inp_out_ids);
|
||||
cb(cur, "sainp_get_rows", il_id);
|
||||
}
|
||||
}
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, wo->splits[id], cur);
|
||||
cb(cur, "kqv_wo", il_id);
|
||||
|
||||
if (!input_added) {
|
||||
cur = ggml_add(ctx0, cur, input);
|
||||
cb(cur, "attn_out_with_input", il);
|
||||
input_added = true;
|
||||
}
|
||||
|
||||
if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
|
||||
cur = ggml_cast(ctx0, cur, lctx.cparams.reduce_type);
|
||||
}
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
attn[id] = cur;
|
||||
}
|
||||
|
||||
cur = ggml_reduce(ctx0, attn.data(), n_device, GGML_OP_ADD);
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
cb(cur, "attn_combined", il);
|
||||
|
||||
} else {
|
||||
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
@@ -8627,65 +8777,54 @@ ggml_cgraph* llm_build_context::build_minimaxm2() {
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
// reshape for multi-head
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
// Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
|
||||
// apply RoPE
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask,
|
||||
n_tokens, kv_head, n_kv,
|
||||
1.0f / sqrtf(float(n_embd_head)), cb, il);
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask,
|
||||
n_tokens, kv_head, n_kv,
|
||||
1.0f / sqrtf(float(n_embd_head)), cb, il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "ffn_inp", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS,cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
|
||||
model.layers[il].ffn_gate_inp, nullptr,
|
||||
model.layers[il].ffn_up_exps, nullptr,
|
||||
model.layers[il].ffn_gate_exps, nullptr,
|
||||
model.layers[il].ffn_down_exps, nullptr,
|
||||
nullptr,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, // no shared experts
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0,
|
||||
LLM_FFN_SILU, true, false, 0.0f,
|
||||
(llm_expert_gating_func_type)hparams.expert_gating_func,
|
||||
cb, il, gf, false, model.layers[il].ffn_up_gate_exps);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
LLM_FFN_SILU, cb, il, gf, true, nullptr);
|
||||
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
@@ -2604,7 +2604,6 @@ bool create_tensors_helper::create_minimaxm2_tensors(const LLM_TN & tn) {
|
||||
create_embd_output(tn, n_embd, n_vocab);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context* ctx_layer = ctx_for_layer(i);
|
||||
ggml_context* ctx_split = ctx_for_layer_split(i);
|
||||
auto& layer = model.layers[i];
|
||||
|
||||
@@ -2613,13 +2612,13 @@ bool create_tensors_helper::create_minimaxm2_tensors(const LLM_TN & tn) {
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k * n_head }, 0);
|
||||
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_k_gqa }, 0);
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k * n_head }, 0);
|
||||
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_k_gqa }, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
use_mmap_buffer &= !create_std_ffn_exps(n_embd, tn, i, 0, n_ff);
|
||||
layer.ffn_exp_probs_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, 0);
|
||||
}
|
||||
@@ -3130,17 +3129,36 @@ bool create_tensors_helper::create_tensors() {
|
||||
LLAMA_LOG_DEBUG("\n");
|
||||
LLAMA_LOG_DEBUG(" split_kq:"); for ([[maybe_unused]] auto s : split_kq) LLAMA_LOG_DEBUG(" %d", s);
|
||||
LLAMA_LOG_DEBUG("\n");
|
||||
if (layer.attn_q_norm) {
|
||||
printf("Layer %2d: q_norm = %ld x %ld, wq = %ld x %ld\n", il, layer.attn_q_norm->ne[0], layer.attn_q_norm->ne[1], layer.wq->ne[0], layer.wq->ne[1]);
|
||||
}
|
||||
if (layer.attn_k_norm) {
|
||||
printf("Layer %2d: k_norm = %ld x %ld, wq = %ld x %ld\n", il, layer.attn_k_norm->ne[0], layer.attn_k_norm->ne[1], layer.wk->ne[0], layer.wk->ne[1]);
|
||||
}
|
||||
|
||||
if (layer.attn_q_norm && layer.attn_q_norm->ne[0] == layer.wq->ne[1]) {
|
||||
// If RMS norm is not applied per attention head, as it is usually the case, but is applied to the
|
||||
// entire Q tensor (e.g., MiniMax-2), we need to have a copy of the entire wq and attn_q_norm tensors
|
||||
// on each participating GPU.
|
||||
prepare_split_tensors(-1, ctx_split, layer.wq, layer.split_wq, split_vo, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split_vo, mem_used);
|
||||
if (layer.bq) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.bq, layer.split_bq, split_vo, mem_used);
|
||||
}
|
||||
printf(" Not splitting wq, attn_q_norm for layer %d\n", il);
|
||||
} else {
|
||||
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split_kq, mem_used);
|
||||
if (layer.attn_q_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split_kq, mem_used);
|
||||
}
|
||||
if (layer.bq) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split_kq, mem_used);
|
||||
}
|
||||
}
|
||||
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split_vo, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split_kq, mem_used);
|
||||
if (layer.bo) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.bo, layer.split_bo, split_vo, mem_used);
|
||||
}
|
||||
if (layer.bq) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split_kq, mem_used);
|
||||
}
|
||||
if (layer.attn_q_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split_kq, mem_used);
|
||||
}
|
||||
if (layer.attn_sinks) {
|
||||
auto split_sinks = split_kq;
|
||||
for (auto & s : split_sinks) {
|
||||
@@ -3150,17 +3168,29 @@ bool create_tensors_helper::create_tensors() {
|
||||
}
|
||||
for (auto & s : split_kq) s /= gqa_ratio;
|
||||
for (auto & s : split_vo) s /= gqa_ratio;
|
||||
prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split_kq, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split_vo, mem_used);
|
||||
if (layer.bk) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split_kq, mem_used);
|
||||
if (layer.attn_k_norm && layer.attn_k_norm->ne[0] == layer.wk->ne[1]) {
|
||||
// If RMS norm is not applied per attention head, as it is usually the case, but is applied to the
|
||||
// entire K tensor (e.g., MiniMax-2), we need to have a copy of the entire wk and attn_k_norm tensors
|
||||
// on each participating GPU.
|
||||
prepare_split_tensors(-1, ctx_split, layer.wk, layer.split_wk, split_vo, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split_vo, mem_used);
|
||||
if (layer.bk) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.bk, layer.split_bk, split_vo, mem_used);
|
||||
}
|
||||
printf(" Not splitting wk, attn_k_norm for layer %d\n", il);
|
||||
} else {
|
||||
prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split_kq, mem_used);
|
||||
if (layer.bk) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split_kq, mem_used);
|
||||
}
|
||||
if (layer.attn_k_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split_kq, mem_used);
|
||||
}
|
||||
}
|
||||
prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split_vo, mem_used);
|
||||
if (layer.bv) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bv, layer.split_bv, split_vo, mem_used);
|
||||
}
|
||||
if (layer.attn_k_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split_kq, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_norm) {
|
||||
|
||||
@@ -779,6 +779,7 @@ static bool llama_kv_cache_init(
|
||||
n_mla++;
|
||||
}
|
||||
else {
|
||||
int n_embd_head_v = hparams.n_embd_head_v;
|
||||
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
|
||||
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
|
||||
@@ -793,6 +794,7 @@ static bool llama_kv_cache_init(
|
||||
auto K = model.layers[i].wk;
|
||||
auto V = model.layers[i].wv;
|
||||
if (K && V && K->extra && V->extra) {
|
||||
bool use_V_for_K = model.layers[i].attn_k_norm && model.layers[i].attn_k_norm->ne[0] == K->ne[1] ? true : false;
|
||||
auto extra_K = (const ggml_split_tensor_t *)K->extra;
|
||||
auto extra_V = (const ggml_split_tensor_t *)V->extra;
|
||||
auto & split_k_l = cache.split_k_l.emplace_back();
|
||||
@@ -800,9 +802,11 @@ static bool llama_kv_cache_init(
|
||||
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
|
||||
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
|
||||
for (int is = 0; is < extra_K->n_device; ++is) {
|
||||
auto split = extra_K->splits[is];
|
||||
auto split = use_V_for_K ? extra_V->splits[is] : extra_K->splits[is];
|
||||
if (!split) continue;
|
||||
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, split->ne[1]/n_embd_head_k * kv_size);
|
||||
int nhead_kv = use_V_for_K ? split->ne[1] / n_embd_head_v : split->ne[1]/n_embd_head_k;
|
||||
if (use_V_for_K) printf("K_cache(%d, %d): using %d instead of %ld heads\n", i, is, nhead_kv, extra_K->splits[is]->ne[1]/n_embd_head_k);
|
||||
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, nhead_kv * kv_size);
|
||||
auto split_name = k_name + '.' + std::to_string(is);
|
||||
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
|
||||
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
|
||||
@@ -1745,6 +1749,7 @@ static bool is_model_split_supported(const llama_model & model) {
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_MINIMAX_M2,
|
||||
};
|
||||
auto it = k_supported.find(model.arch);
|
||||
return it != k_supported.end();
|
||||
|
||||
Reference in New Issue
Block a user