Grouped expert routing (CPU only) (#836)

* Better argsort (CPU)

* Attempt at grouped topk

* This seems to do the trick for grouped expert routing

* Cleanup

* Trying to merge, something is not right

* Working merged grouped top_k (CPU)

* Add command line option to enable grouped expert routing

* Add grouped expert routing option to llama-bench

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-10-16 14:57:02 +03:00 (committed via GitHub)
Commit: cde642e591 (parent: e66d307e13)
11 changed files with 221 additions and 44 deletions
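
Background: grouped (group-limited) expert routing splits the router probabilities for each token into num_groups contiguous groups, scores every group, keeps only the num_top_groups best groups, and picks the final topk_experts expert indices from the survivors. Illustrative numbers (not taken from this commit): with 256 routed experts in 8 groups of 32, num_top_groups = 4 and nk = 2, each group is scored by the sum of its two largest probabilities, 4 groups (128 experts) survive, and the overall top 8 of those become the selected experts.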


@@ -650,6 +650,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_ARGSORT_THRESH,
GGML_OP_GROUPED_TOPK,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
GGML_OP_SOFT_CAP_MAX,
@@ -2265,6 +2266,13 @@ extern "C" {
int k,
int min_entries,
float thresh);
GGML_API struct ggml_tensor * ggml_grouped_topk(
struct ggml_context * ctx,
struct ggml_tensor * a,
int num_groups,
int num_top_groups,
int nk,
int topk_experts);
#define GGML_KQ_MASK_PAD 16
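
A minimal sketch of how the new API might be called when building a MoE routing graph; the context ctx, the tensor probs and all numeric values are illustrative, not taken from this commit:

    // probs: F32 router probabilities, ne[0] = 256 experts, ne[1] = n_tokens
    struct ggml_tensor * selected = ggml_grouped_topk(ctx, probs,
            /*num_groups    =*/ 8,   // split the 256 experts into 8 groups of 32
            /*num_top_groups=*/ 4,   // keep the 4 best-scoring groups
            /*nk            =*/ 2,   // score a group by the sum of its top-2 probs
            /*topk_experts  =*/ 8);  // final number of experts per token
    // selected is GGML_TYPE_I32 with ne[0] = 8: the chosen expert indices per token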


@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"ARGSORT_THRESH",
"GROUPED_TOPK",
"LEAKY_RELU",
"SOFTCAP",
"SOFT_CAP_MAX",
@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"argsort_thresh(x)",
"grouped_topk(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
"soft_max(k2*tanh(k1*x))",
@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x),"
};
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh(
return result;
}
struct ggml_tensor * ggml_grouped_topk(
struct ggml_context * ctx,
struct ggml_tensor * a,
int num_groups,
int num_top_groups,
int nk,
int topk_experts) {
GGML_ASSERT(num_top_groups <= num_groups);
GGML_ASSERT(a->ne[0] % num_groups == 0);
GGML_ASSERT(a->ne[0] >= topk_experts);
int64_t n_per_group = a->ne[0] / num_groups;
GGML_ASSERT(n_per_group >= nk);
bool is_node = false;
int64_t ne[GGML_MAX_DIMS];
for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i];
ne[0] = topk_experts;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne);
ggml_set_op_params_i32(result, 0, num_groups);
ggml_set_op_params_i32(result, 1, num_top_groups);
ggml_set_op_params_i32(result, 2, nk);
result->op = GGML_OP_GROUPED_TOPK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_top_k
struct ggml_tensor * ggml_top_k(
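
Note that only three of the four arguments are stored in op_params; topk_experts travels implicitly as the first dimension of the I32 result. The CPU kernel (iqk_grouped_top_k, last file below) recovers all four like this:

    int n_groups     = dst->op_params[0];
    int n_top_groups = dst->op_params[1];
    int nk           = dst->op_params[2];
    int topk_experts = dst->ne[0];  // implicit: row length of the output
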
@@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh(
}
}
static void ggml_compute_forward_grouped_topk(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
iqk_grouped_top_k(dst, params->ith, params->nth);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16(
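
Only F32 inputs are handled; any other source type hits GGML_ABORT. The actual work is delegated to iqk_grouped_top_k (last file below), which receives the thread index ith and thread count nth for row-parallel execution.
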
@@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_argsort_thresh(params, tensor);
} break;
case GGML_OP_GROUPED_TOPK:
{
ggml_compute_forward_grouped_topk(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_GROUPED_TOPK:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_LEAKY_RELU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_ARGSORT_THRESH:
case GGML_OP_GROUPED_TOPK:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
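
For GGML_OP_GROUPED_TOPK, like GGML_OP_ARGSORT, ggml_get_n_tasks hands the op the full thread count, and the kernel splits the rows (tokens) across threads. With the slicing used in iqk_grouped_top_k below, e.g. nrows = 10 and nth = 4 give npt = 3 and the row ranges [0,3), [3,6), [6,9), [9,10):

    int64_t npt   = (nrows + nth - 1)/nth;   // rows per thread, rounded up
    int64_t first = npt*ith;
    int64_t last  = std::min(first + npt, nrows);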


@@ -10,8 +10,97 @@
#include <cstdint>
#include <vector>
#include <algorithm>
#include <cmath>
-void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
-}
namespace {
// Playing around with group scores: use sum of probabilities in the group
inline float group_score(int n_per_group, const float * data) {
float sum = 0;
for (int j = 0; j < n_per_group; ++j) sum += data[j];
return sum;
}
// Playing around with group scores: use max of probabilities in the group
inline float group_score_max(int n_per_group, const float * data) {
float max = data[0];
for (int j = 1; j < n_per_group; ++j) max = std::max(max, data[j]);
return max;
}
// Actual top-nk group score: sum of top-nk probabilities in the group
inline float group_score(int n_per_group, int nk, const float * data, float * aux) {
for (int j = 0; j < n_per_group; ++j) aux[j] = data[j];
std::partial_sort(aux, aux + nk, aux + n_per_group, std::greater<float>{});
float sum = 0;
for (int j = 0; j < nk; ++j) sum += aux[j];
return sum;
}
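// Worked example for the top-nk score: with n_per_group = 4, nk = 2 and group
// probabilities {0.05, 0.40, 0.10, 0.30}, the two largest values give
// 0.40 + 0.30 = 0.70; the plain-sum variant would return 0.85, the max variant 0.40.

// Per-thread scratch buffer, grown on demand; thread_local avoids both a per-row
// allocation and any sharing between worker threads.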
inline std::vector<std::pair<float,int>> & get_work_buffer(size_t size) {
thread_local std::vector<std::pair<float,int>> buffer;
if (buffer.size() < size) buffer.resize(size);
return buffer;
}
}
void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) {
auto src = dst->src[0];
GGML_ASSERT(dst->type == GGML_TYPE_I32);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
auto nrows = ggml_nrows(src);
auto npt = (nrows + nth - 1)/nth;
auto first = npt*ith;
auto last = std::min(first + npt, nrows);
if (last <= first) return;
int n_groups = dst->op_params[0];
int n_top_groups = dst->op_params[1];
int nk = dst->op_params[2];
int ne00 = src->ne[0];
int ne0 = dst->ne[0];
GGML_ASSERT(ne0 <= ne00);
GGML_ASSERT(ne00%n_groups == 0);
int n_per_group = ne00/n_groups;
GGML_ASSERT(nk <= n_per_group);
GGML_ASSERT(n_top_groups <= n_groups);
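// Scratch layout: the first n_per_group*n_top_groups pairs collect the
// (probability, expert index) candidates from the surviving groups (the same
// region, cast to float *, is reused when computing the top-nk group scores);
// the trailing n_groups pairs hold the (score, group index) entries.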
size_t work_size = n_groups + n_per_group*n_top_groups;
auto& aux = get_work_buffer(work_size);
auto groups = aux.data() + n_per_group*n_top_groups;
for (int ir = first; ir < last; ++ir) {
auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
auto result = (int32_t *)((char *)dst->data + ir*dst->nb[1]);
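// Each row is one token: data holds its ne00 router probabilities, result
// receives the ne0 selected expert indices.
// Degenerate case first: if more experts are requested than the n_top_groups
// surviving groups can supply, group filtering cannot help, so fall back to
// the identity mapping.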
if (ne0 > n_per_group*n_top_groups) {
for (int j = 0; j < ne0; ++j) result[j] = j;
continue;
}
if (n_top_groups < n_groups) {
for (int ig = 0; ig < n_groups; ++ig) {
//groups[ig] = { group_score(n_per_group, data + ig*n_per_group), ig };
//groups[ig] = { group_score_max(n_per_group, data + ig*n_per_group), ig };
groups[ig] = { group_score(n_per_group, nk, data + ig*n_per_group, (float *)aux.data()), ig };
}
std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater<std::pair<float,int>>{});
for (int ig = 0; ig < n_top_groups; ++ig) {
int i0 = n_per_group * ig;
int j0 = n_per_group * groups[ig].second;
for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { data[j0 + j], j0 + j };
}
} else {
for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j };
}
if (ne0 < n_top_groups*n_per_group) {
std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater<std::pair<float,int>>{});
} else {
std::sort(aux.begin(), aux.begin() + ne0, std::greater<std::pair<float,int>>{});
}
for (int j = 0; j < ne0; ++j) result[j] = aux[j].second;
}
}
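
// Worked example (illustrative): ne00 = 8, n_groups = 4, n_top_groups = 2, nk = 1,
// ne0 = 2, row = {0.10, 0.05, 0.30, 0.02, 0.01, 0.20, 0.08, 0.04}.
// Top-1 group scores are {0.10, 0.30, 0.20, 0.08}, so groups 1 and 2 survive; the
// candidates are {(0.30,2), (0.02,3), (0.01,4), (0.20,5)} and the partial sort
// yields the expert indices {2, 5}.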
void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
@@ -30,8 +119,7 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
int nk = dst->op_params[1];
int ne00 = src->ne[0];
-thread_local std::vector<std::pair<float,int>> aux;
-if ((int)aux.size() < ne00) aux.resize(ne00);
+auto& aux = get_work_buffer(ne00);
for (int ir = first; ir < last; ++ir) {
auto data = (const float *)((const char *)src->data + ir*src->nb[1]);