This results in faster PP.

Now PP is faster than split mode layer for L3-70B.
This commit is contained in:
Kawrakow
2025-11-27 14:46:19 +00:00
parent e7d897e26f
commit d8d9c7bdca
2 changed files with 68 additions and 5 deletions

View File

@@ -649,7 +649,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
std::vector<ggml_tensor *> ffn;
ffn.reserve(u->n_device);
for (int id = 0; id < u->n_device; ++id) {
int il_cb = 1000*id + il;
int il_cb = 1000*(id+1) + il;
auto split_u = u->splits[id];
auto split_g = g->splits[id];
auto split_d = d->splits[id];
@@ -659,6 +659,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il_cb);
}
else if (input->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, input, GGML_TYPE_F32);
}
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
@@ -668,6 +672,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
if (cur->ne[1] >= 32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
}
if (graph) {
ggml_build_forward_expand(graph, cur);
}
@@ -676,11 +683,18 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (ffn.size() == 1) return ffn.front();
auto cur = ggml_add(ctx, ffn[0], ffn[1]);
cb(cur, "combine_ffn", il);
cur->op_params[0] = 0xff;
for (int id = 2; id < int(ffn.size()); ++id) {
cur = ggml_add(ctx, cur, ffn[id]);
cb(cur, "combine_ffn", il);
}
cur->op_params[0] = 0xff;
if (ffn.size() > 2) {
cur->op_params[0] = 0xff;
}
if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
return cur;
}
@@ -688,6 +702,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
input = llm_build_norm(ctx, input, lctx.model.hparams, ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(input, "ffn_norm", il);
}
else if (input->type != GGML_TYPE_F32) {
input = ggml_cast(ctx, input, GGML_TYPE_F32);
}
if (lctx.cparams.fused_up_gate &&
up && gate && !up_b && !up_s && !gate_b && !gate_s && type_gate == LLM_FFN_PAR &&
@@ -1621,6 +1638,7 @@ ggml_cgraph * llm_build_context::build_llama() {
auto split = output->splits[id];
if (!split) continue;
o.push_back(llm_build_lora_mm(lctx, ctx0, split, cur));
cb(o.back(), "output", id);
}
if (o.size() == 1) cur = o.front();
cur = ggml_concat(ctx0, o[0], o[1], 0);
@@ -8968,7 +8986,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
std::vector<ggml_tensor*> attn; attn.reserve(wq->n_device);
for (int id = 0; id < wq->n_device; ++id) {
int il_cb = 1000*id + il;
int il_cb = 1000*(id+1) + il;
auto split_wq = wq->splits[id];
auto split_wk = wk->splits[id];
auto split_wv = wv->splits[id];
@@ -9058,6 +9076,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
#endif
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, "flash_attn", il_cb);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
@@ -9071,6 +9090,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
}
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il_cb);
cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
@@ -9078,6 +9098,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
cb(cur, "kqv_wo", il_cb);
if (cur->ne[1] >= 32) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
}
ggml_build_forward_expand(gf, cur);
// TODO: wo_b
attn.push_back(cur);
@@ -9085,11 +9108,14 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (attn.size() == 1) return attn.front();
auto cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il);
cur->op_params[0] = 0xff;
for (int id = 2; id < (int)attn.size(); ++id) {
cur = ggml_add(ctx0, cur, attn[id]);
cb(cur, "combine_attn", il);
}
cur->op_params[0] = 0xff;
if (attn.size() > 2) {
cur->op_params[0] = 0xff;
}
return cur;
}
}