From ffb393230095f17ecc0dccabf5376e6aa60efddf Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 15 Oct 2025 17:29:10 +0300 Subject: [PATCH] Attemt at grouped topk --- ggml/include/ggml.h | 7 ++++ ggml/src/ggml.c | 62 ++++++++++++++++++++++++++++++++++-- ggml/src/iqk/iqk_cpu_ops.cpp | 56 ++++++++++++++++++++++++++++++++ src/llama-build-context.cpp | 58 +++++++++++++++++---------------- 4 files changed, 154 insertions(+), 29 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a5fda8a4..a784461f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -650,6 +650,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_ARGSORT_THRESH, + GGML_OP_GROUPED_TOPK, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, GGML_OP_SOFT_CAP_MAX, @@ -2265,6 +2266,12 @@ extern "C" { int k, int min_entries, float thresh); + GGML_API struct ggml_tensor * ggml_grouped_topk( + struct ggml_context * ctx, + struct ggml_tensor * a, + int num_groups, + int num_top_groups, + int nk); #define GGML_KQ_MASK_PAD 16 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bfd5e41e..444825f5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "ARGSORT_THRESH", + "GROUPED_TOPK", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "argsort_thresh(x)", + "grouped_topk(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)," }; -static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); +static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -9439,6 +9441,35 @@ struct ggml_tensor * ggml_argsort_thresh( return result; } +struct ggml_tensor * ggml_grouped_topk( + struct ggml_context * ctx, + struct ggml_tensor * a, + int num_groups, + int num_top_groups, + int nk) { + + GGML_ASSERT(num_top_groups < num_groups); + GGML_ASSERT(a->ne[0] % num_groups == 0); + int64_t n_per_group = a->ne[0] / num_groups; + GGML_ASSERT(n_per_group >= nk); + + bool is_node = false; + + //struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, num_groups); + ggml_set_op_params_i32(result, 1, num_top_groups); + ggml_set_op_params_i32(result, 2, nk); + + result->op = GGML_OP_GROUPED_TOPK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + // ggml_top_k struct ggml_tensor * ggml_top_k( @@ -20024,6 +20055,24 @@ static void ggml_compute_forward_argsort_thresh( } } +static void ggml_compute_forward_grouped_topk( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + iqk_grouped_top_k(dst, params->ith, params->nth); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -22521,6 +22570,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_argsort_thresh(params, tensor); } break; + case GGML_OP_GROUPED_TOPK: + { + ggml_compute_forward_grouped_topk(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -23539,6 +23592,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_GROUPED_TOPK: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -24281,6 +24338,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_ARGSORT_THRESH: + case GGML_OP_GROUPED_TOPK: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index e37cb5c2..3695e215 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -4,8 +4,64 @@ #include #include #include +#include void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) { + auto src = dst->src[0]; + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_are_same_shape(src, dst)); + + auto nrows = ggml_nrows(src); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int n_groups = dst->op_params[0]; + int n_top_groups = dst->op_params[1]; + int nk = dst->op_params[2]; + + //if (ith == 0) printf("%s: ne00 = %ld, n_groups = %d, n_top_groups = %d, nk = %d\n", __func__, src->ne[0], n_groups, n_top_groups, nk); + + int ne00 = src->ne[0]; + GGML_ASSERT(ne00%n_groups == 0); + int n_per_group = ne00/n_groups; + GGML_ASSERT(nk <= n_per_group); + + thread_local std::vector> aux; + if ((int)aux.size() < n_per_group + n_groups) aux.resize(n_per_group + n_groups); + + auto groups = aux.data() + n_per_group; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)src->data + ir*src->nb[1]); + auto result = (float *)((char *)dst->data + ir*dst->nb[1]); + for (int j = 0; j < ne00; ++j) result[j] = -INFINITY; + for (int ig = 0; ig < n_groups; ++ig) { + for (int j = 0; j < n_per_group; ++j) { + int jj = ig*n_per_group + j; + aux[j] = { data[jj], jj }; + } + std::partial_sort(aux.begin(), aux.begin() + nk, aux.end(), std::greater>{}); + for (int j = 0; j < nk; ++j) result[aux[j].second] = data[aux[j].second]; + //float sum = 0; + //for (int j = 0; j < nk; ++j) sum += aux[j].first; + //groups[ig] = { sum, ig }; + } + //std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater>{}); + + //for (int ig = 0; ig < n_top_groups; ++ig) { + // int jg = groups[ig].second; + // for (int j = 0; j < n_per_group; ++j) result[jg*n_per_group + j] = data[jg*n_per_group + j]; + //} + //for (int ig = n_top_groups; ig < n_groups; ++ig) { + // int jg = groups[ig].second; + // for (int j = 0; j < n_per_group; ++j) result[jg*n_per_group + j] = -INFINITY; + //} + + } + } void iqk_argsort(ggml_tensor * dst, int ith, int nth) { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e1817bae..cbb2b8d3 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -820,36 +820,40 @@ llm_expert_gating_func_type gating_op, selection_probs = logits; } - if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { + if (true && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { auto& hparams = lctx.model.hparams; - const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; - // organize experts into n_expert_groups - ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups] -#if 0 - ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups] - group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] - - // get top n_group_used expert groups - group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] -#else - // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead - ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups] - group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups] - - // get top n_group_used expert groups - group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1] -#endif - ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1] - cb(expert_groups->src[0], "ffn_moe_group_argsort", il); - cb(expert_groups, "ffn_moe_group_topk", il); - - // mask out the other groups - selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used] - selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups] - selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert] - selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens] + selection_probs = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2); cb(selection_probs, "ffn_moe_probs_masked", il); + +// const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; +// +// // organize experts into n_expert_groups +// ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups] +//#if 0 +// ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups] +// group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] +// +// // get top n_group_used expert groups +// group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] +//#else +// // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead +// ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups] +// group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups] +// +// // get top n_group_used expert groups +// group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1] +//#endif +// ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1] +// cb(expert_groups->src[0], "ffn_moe_group_argsort", il); +// cb(expert_groups, "ffn_moe_group_topk", il); +// +// // mask out the other groups +// selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used] +// selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups] +// selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert] +// selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens] +// cb(selection_probs, "ffn_moe_probs_masked", il); } // select experts