Be able to set reduce op data type for split mode "graph"

This commit is contained in:
Iwan Kawrakow
2025-12-24 10:57:41 +00:00
parent 1d7d0225a0
commit c6a3903571
7 changed files with 23 additions and 6 deletions

View File

@@ -697,7 +697,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
if (cur->ne[1] >= 32) {
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
}
if (graph) {
@@ -1185,7 +1185,7 @@ llm_expert_gating_func_type gating_op,
split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
cb(shared_out, "ffn_shexp_out", il_cb);
if (shared_out->ne[1] > 32) {
if (shared_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
}
results.push_back(shared_out);
@@ -1202,7 +1202,7 @@ llm_expert_gating_func_type gating_op,
cb(cur, "ffn_shared_combined", il);
}
}
if (routed_out->ne[1] > 32) {
if (routed_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
auto routed_out_f16 = ggml_cast(ctx, routed_out, GGML_TYPE_F16);
cur = ggml_add(ctx, routed_out_f16, cur);
} else {
@@ -1279,7 +1279,7 @@ llm_expert_gating_func_type gating_op,
} else {
cur = routed_out;
}
if (cur->ne[1] >= 32) {
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
cb(cur, "ffn_out_f16", il_cb);
}
@@ -9513,7 +9513,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cur = ggml_add(ctx0, cur, bo->splits[id]);
cb(cur, "kqv_wo_biased", il_cb);
}
if (cur->ne[1] >= 32) {
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
}
ggml_build_forward_expand(gf, cur);