Copy reduce result to other GPUs if necessary (#1156)

This commit is contained in:
Kawrakow
2026-01-19 08:40:26 +02:00
committed by GitHub
parent 6dfbef27ec
commit 0c0b6e4b8b
3 changed files with 91 additions and 49 deletions

View File

@@ -619,14 +619,19 @@ ggml_tensor * llm_build_context::llm_build_norm(
return cur;
}
static ggml_tensor * get_input_tensor_sm_graph(ggml_tensor * input, int id) {
static ggml_tensor * get_input_tensor_sm_graph(ggml_context * ctx, ggml_tensor * input, int id) {
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src || !cur) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
if (!cur) {
GGML_ASSERT((input->op_params[4] & (1u << id)) == 0);
cur = ggml_dup_tensor(ctx, input);
input->src[id] = cur;
input->op_params[4] |= (1u << id);
}
else if (cur == view_src) {
cur = input;
}
}
@@ -693,7 +698,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
auto cur = get_input_tensor_sm_graph(input, id);
auto cur = get_input_tensor_sm_graph(ctx, input, id);
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
@@ -1277,17 +1282,8 @@ llm_expert_gating_func_type gating_op,
(!split_up_exps->splits[id] && !split_gate_exps->splits[id] && !split_down_exps->splits[id]));
if (!split_up_exps->splits[id]) continue;
int il_cb = 1000*(id + 1) + il;
auto cur = get_input_tensor_sm_graph(input, id);
auto cur = get_input_tensor_sm_graph(ctx, input, id);
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, false);
//if (ffn_norm) {
// auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
// GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
// cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
// cb(cur, "ffn_inp_normed", il_cb);
//}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
GGML_ASSERT(!split_gate_inp_b || split_gate_inp_b->splits[id]);
GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
@@ -9190,7 +9186,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
//ggml_split_tensor_t * attn_norm = the_attn_norm ? (ggml_split_tensor_t *)the_attn_norm->extra : nullptr;
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
@@ -9230,18 +9225,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue;
auto cur = get_input_tensor_sm_graph(input, id);
auto cur = get_input_tensor_sm_graph(ctx0, input, id);
cur = do_split_norm(ctx0, cur, the_attn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
//if (attn_norm) {
// if (is_norm) {
// cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
// } else {
// cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
// }
//}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
//}
auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?