Grouped expert routing (CPU only) (#836)

* Better argsort (CPU)

* Attempt at grouped topk

* This seems to do the trick for grouped expert routing

* Cleanup

* Trying to merge, something is not right

* Working merged grouped top_k (CPU)

* Add command line option to enable grouped expert routing

* Add grouped expert routing option to llama-bench

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-10-16 14:57:02 +03:00
Committed by: GitHub
Parent: ecf8f931ea
Commit: dbfd151594

11 changed files with 221 additions and 44 deletions
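In one paragraph: for BailingMoeV2-style MoE models, grouped expert routing partitions the n_expert routing probabilities into hparams.n_expert_groups equal groups, scores each group (in the version that landed, by the sum of its top-2 expert probabilities), keeps only the hparams.n_group_used best-scoring groups, and selects the final top-k experts from the survivors. An earlier attempt expressed this as a long chain of ggml view/argsort/get_rows ops left disabled behind an `if (false)`; this commit removes that chain and replaces it with a single CPU op, GGML_OP_GROUPED_TOPK, toggled at runtime with `-ger`/`--grouped-expert-routing`.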


@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.fused_moe_up_gate = true;
         return true;
     }
+    if (arg == "-ger" || arg == "--grouped-expert-routing") {
+        params.grouped_expert_routing = true;
+        return true;
+    }
     if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
         params.fused_up_gate = false;
         return true;
@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.mla_attn = params.mla_attn;
     cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+    cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
     fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+    fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);


@@ -235,6 +235,7 @@ struct gpt_params {
     int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
     bool fused_up_gate = true; // fused up*unary(gate) op
+    bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
     int min_experts = -1;
     float thresh_experts = 0;


@@ -261,6 +261,7 @@ struct cmd_params {
     bool warmup;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false; // ger = Grouped Expert Routing
     bool no_fug = false;
     bool use_thp = false;
     output_formats output_format;
@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
     /* verbose */ false,
     /* warmup */ true,
     /* repack */ false,
-    /* use_thp */ false,
     /* fmoe */ false,
+    /* ger */ false,
     /* no_fug */ false,
+    /* use_thp */ false,
     /* output_format */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };
@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
     printf(" -ot, --override-tensor pattern (default: none)\n");
     printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
+    printf(" -ger, --grouped-expert-routing <0|1>(default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
     printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
     printf("\n");
     printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.fmoe = std::stoi(argv[i]);
+        } else if (arg == "-ger" || arg == "--grouped-expert-routing") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.ger = std::stoi(argv[i]);
        } else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
             if (++i >= argc) {
                 invalid_param = true;
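One thing to note for benchmarking: like `-fmoe` just above it, the new `-ger` flag is read with a single std::stoi rather than llama-bench's usual comma-separated value-list parsing, so it applies one setting to the whole run; comparing the two routing paths takes two invocations, one with `-ger 0` and one with `-ger 1`.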
@@ -829,6 +838,7 @@ struct cmd_params_instance {
     bool embeddings;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false;
     bool no_fug = false;
     bool use_thp = false;
     const llama_model_tensor_buft_override* buft_overrides;
@@ -876,6 +886,7 @@ struct cmd_params_instance {
         cparams.mla_attn = mla_attn;
         cparams.attn_max_batch = attn_max_batch;
         cparams.fused_moe_up_gate = fmoe;
+        cparams.grouped_expert_routing = ger;
         cparams.fused_up_gate = !no_fug;
         cparams.min_experts = ser.first;
         cparams.thresh_experts = ser.second;
@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .embeddings = */ embd,
             /* .repack = */ params.repack,
             /* .fmoe = */ params.fmoe,
+            /* .ger = */ params.ger,
             /* .no_fug = */ params.no_fug,
             /* .use_thp = */ params.use_thp,
             /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .embeddings = */ embd,
             /* .repack = */ params.repack,
             /* .fmoe = */ params.fmoe,
+            /* .ger = */ params.ger,
             /* .no_fug = */ params.no_fug,
             /* .use_thp = */ params.use_thp,
             /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .embeddings = */ embd,
             /* .repack = */ params.repack,
             /* .fmoe = */ params.fmoe,
+            /* .ger = */ params.ger,
             /* .no_fug = */ params.no_fug,
             /* .use_thp = */ params.use_thp,
             /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .embeddings = */ embd,
             /* .repack = */ params.repack,
             /* .fmoe = */ params.fmoe,
+            /* .ger = */ params.ger,
             /* .no_fug = */ params.no_fug,
             /* .use_thp = */ params.use_thp,
             /* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1086,6 +1101,7 @@ struct test {
     bool embeddings;
     bool repack = false;
     bool fmoe = false;
+    bool ger = false;
     bool no_fug = false;
     bool use_thp = false;
     int n_prompt;
@@ -1120,6 +1136,8 @@ struct test {
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
         repack = inst.repack;
+        fmoe = inst.fmoe;
+        ger = inst.ger;
         no_fug = inst.no_fug;
         use_thp = inst.use_thp;
         n_prompt = inst.n_prompt;
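The `fmoe = inst.fmoe;` line is worth a second look: it arrives as an addition alongside `ger`, which suggests the fused-MoE setting was previously never copied from the instance into the test record, so llama-bench's `fmoe` column would have always shown the default. The commit appears to fix that in passing.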
@@ -1212,7 +1230,7 @@ struct test {
         "n_threads", "type_k", "type_v",
         "n_gpu_layers", "split_mode",
         "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
-        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp",
+        "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",
         "avg_ts", "stddev_ts", "test",
@@ -1234,7 +1252,7 @@ struct test {
         if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
             field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
-            field == "fused_moe" || field == "fused_up_gate") {
+            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -1277,7 +1295,8 @@ struct test {
             std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
             std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
-            std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp),
+            std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
+            std::to_string(no_fug), std::to_string(use_thp),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
             std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return 4;
         }
+        if (field == "grouped_er") {
+            return 3;
+        }
         if (field == "fused_up_gate") {
             return 6;
         }
@@ -1513,6 +1535,12 @@ struct markdown_printer : public printer {
         if (field == "fused_moe") {
             return "fmoe";
         }
+        if (field == "grouped_er") {
+            return "ger";
+        }
+        if (field == "grouped_er") {
+            return "ger";
+        }
         if (field == "fused_up_gate") {
             return "no-fug";
         }
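(The `grouped_er` block really is added twice: the hunk header's 6-to-12 line count says the duplication is in the commit itself, not a rendering artifact of this page. The second copy can never fire after the first and is harmless dead code.)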
@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
         if (params.fmoe != cmd_params_defaults.fmoe) {
             fields.emplace_back("fused_moe");
         }
+        if (params.ger != cmd_params_defaults.ger) {
+            fields.emplace_back("grouped_er");
+        }
         if (params.no_fug != cmd_params_defaults.no_fug) {
             fields.emplace_back("fused_up_gate");
         }


@@ -650,6 +650,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_ARGSORT_THRESH,
+        GGML_OP_GROUPED_TOPK,
         GGML_OP_LEAKY_RELU,
         GGML_OP_SOFTCAP,
         GGML_OP_SOFT_CAP_MAX,
@@ -2265,6 +2266,13 @@ extern "C" {
             int k,
             int min_entries,
             float thresh);
+    GGML_API struct ggml_tensor * ggml_grouped_topk(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int num_groups,
+            int num_top_groups,
+            int nk,
+            int topk_experts);

 #define GGML_KQ_MASK_PAD 16
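A usage sketch, not part of the commit: the snippet below drives the new op directly through the public ggml API on a toy routing row, assuming a CPU build of this fork (where GGML_OP_GROUPED_TOPK dispatches to iqk_grouped_top_k, shown further down). The tensor sizes, probabilities, and expected output are invented for illustration.

#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // one token routed over 8 experts, arranged as 4 groups of 2
    struct ggml_tensor * probs = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 1);
    const float vals[8] = {0.05f, 0.30f, 0.10f, 0.02f, 0.25f, 0.20f, 0.01f, 0.07f};
    for (int i = 0; i < 8; ++i) ((float *)probs->data)[i] = vals[i];

    // keep the best 2 of the 4 groups (each scored by its top-1 probability, nk = 1),
    // then pick the top-3 experts among the 4 survivors
    struct ggml_tensor * sel = ggml_grouped_topk(ctx, probs, /*num_groups*/ 4, /*num_top_groups*/ 2, /*nk*/ 1, /*topk_experts*/ 3);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, sel);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    // group scores are 0.30, 0.10, 0.25, 0.07 -> groups 0 and 2 survive;
    // the top-3 of {0.05, 0.30, 0.25, 0.20} are experts 1, 4, 5
    for (int i = 0; i < 3; ++i) printf("%d ", ((int32_t *)sel->data)[i]);
    printf("\n");

    ggml_free(ctx);
    return 0;
}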


@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "ARGSORT_THRESH",
+    "GROUPED_TOPK",
     "LEAKY_RELU",
     "SOFTCAP",
     "SOFT_CAP_MAX",
@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };

-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "argsort_thresh(x)",
+    "grouped_topk(x)",
     "leaky_relu(x)",
     "k2*tanh(k1*x)",
     "soft_max(k2*tanh(k1*x))",
@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x),"
 };

-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh(
     return result;
 }

+struct ggml_tensor * ggml_grouped_topk(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int num_groups,
+        int num_top_groups,
+        int nk,
+        int topk_experts) {
+    GGML_ASSERT(num_top_groups <= num_groups);
+    GGML_ASSERT(a->ne[0] % num_groups == 0);
+    GGML_ASSERT(a->ne[0] >= topk_experts);
+    int64_t n_per_group = a->ne[0] / num_groups;
+    GGML_ASSERT(n_per_group >= nk);
+
+    bool is_node = false;
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i];
+    ne[0] = topk_experts;
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, num_groups);
+    ggml_set_op_params_i32(result, 1, num_top_groups);
+    ggml_set_op_params_i32(result, 2, nk);
+
+    result->op = GGML_OP_GROUPED_TOPK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_top_k

 struct ggml_tensor * ggml_top_k(
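Note the contract: the result is an I32 index tensor whose first dimension is topk_experts, exactly like ggml_top_k's output, so downstream consumers (the ggml_get_rows over the reshaped probabilities) do not care which of the two ops produced the indices. num_groups, num_top_groups, and nk travel to the compute kernel via op_params, while topk_experts is implied by the output shape.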
@@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh(
     }
 }

+static void ggml_compute_forward_grouped_topk(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                iqk_grouped_top_k(dst, params->ith, params->nth);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16(
// ggml_compute_forward_flash_attn_ext // ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16( static void ggml_compute_forward_flash_attn_ext_f16(
@@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             {
                 ggml_compute_forward_argsort_thresh(params, tensor);
             } break;
+        case GGML_OP_GROUPED_TOPK:
+            {
+                ggml_compute_forward_grouped_topk(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
@@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
+        case GGML_OP_GROUPED_TOPK:
+            {
+                GGML_ABORT("fatal error"); // TODO: not implemented
+            }
         case GGML_OP_LEAKY_RELU:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_ARGSORT_THRESH:
+        case GGML_OP_GROUPED_TOPK:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:


@@ -10,8 +10,97 @@
 #include <cstdint>
 #include <vector>
 #include <algorithm>
+#include <cmath>

-void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
+namespace {
+
+// Playing around with group scores: use sum of probabilities in the group
+inline float group_score(int n_per_group, const float * data) {
+    float sum = 0;
+    for (int j = 0; j < n_per_group; ++j) sum += data[j];
+    return sum;
+}
+
+// Playing around with group scores: use max of probabilities in the group
+inline float group_score_max(int n_per_group, const float * data) {
+    float max = data[0];
+    for (int j = 1; j < n_per_group; ++j) max = std::max(max, data[j]);
+    return max;
+}
+
+// Actual top-nk group score: sum of top-nk probabilities in the group
+inline float group_score(int n_per_group, int nk, const float * data, float * aux) {
+    for (int j = 0; j < n_per_group; ++j) aux[j] = data[j];
+    std::partial_sort(aux, aux + nk, aux + n_per_group, std::greater<float>{});
+    float sum = 0;
+    for (int j = 0; j < nk; ++j) sum += aux[j];
+    return sum;
+}
+
+inline std::vector<std::pair<float,int>> & get_work_buffer(size_t size) {
+    thread_local std::vector<std::pair<float,int>> buffer;
+    if (buffer.size() < size) buffer.resize(size);
+    return buffer;
+}
+
+}
+
+void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) {
+    auto src = dst->src[0];
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
+    auto nrows = ggml_nrows(src);
+    auto npt = (nrows + nth - 1)/nth;
+    auto first = npt*ith;
+    auto last = std::min(first + npt, nrows);
+    if (last <= first) return;
+    int n_groups = dst->op_params[0];
+    int n_top_groups = dst->op_params[1];
+    int nk = dst->op_params[2];
+    int ne00 = src->ne[0];
+    int ne0 = dst->ne[0];
+    GGML_ASSERT(ne0 <= ne00);
+    GGML_ASSERT(ne00%n_groups == 0);
+    int n_per_group = ne00/n_groups;
+    GGML_ASSERT(nk <= n_per_group);
+    GGML_ASSERT(n_top_groups <= n_groups);
+    size_t work_size = n_groups + n_per_group*n_top_groups;
+    auto& aux = get_work_buffer(work_size);
+    auto groups = aux.data() + n_per_group*n_top_groups;
+    for (int ir = first; ir < last; ++ir) {
+        auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
+        auto result = (int32_t *)((char *)dst->data + ir*dst->nb[1]);
+        if (ne0 > n_per_group*n_top_groups) {
+            for (int j = 0; j < ne0; ++j) result[j] = j;
+            continue;
+        }
+        if (n_top_groups < n_groups) {
+            for (int ig = 0; ig < n_groups; ++ig) {
+                //groups[ig] = { group_score(n_per_group, data + ig*n_per_group), ig };
+                //groups[ig] = { group_score_max(n_per_group, data + ig*n_per_group), ig };
+                groups[ig] = { group_score(n_per_group, nk, data + ig*n_per_group, (float *)aux.data()), ig };
+            }
+            std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater<std::pair<float,int>>{});
+            for (int ig = 0; ig < n_top_groups; ++ig) {
+                int i0 = n_per_group * ig;
+                int j0 = n_per_group * groups[ig].second;
+                for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { data[j0 + j], j0 + j };
+            }
+        } else {
+            for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j };
+        }
+        if (ne0 < n_top_groups*n_per_group) {
+            std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater<std::pair<float,int>>{});
+        } else {
+            std::sort(aux.begin(), aux.begin() + ne0, std::greater<std::pair<float,int>>{});
+        }
+        for (int j = 0; j < ne0; ++j) result[j] = aux[j].second;
+    }
 }

 void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
@@ -30,8 +119,7 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
     int nk = dst->op_params[1];
     int ne00 = src->ne[0];
-    thread_local std::vector<std::pair<float,int>> aux;
-    if ((int)aux.size() < ne00) aux.resize(ne00);
+    auto& aux = get_work_buffer(ne00);
     for (int ir = first; ir < last; ++ir) {
         auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
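To make the scoring rule concrete, here is a minimal standalone reproduction of the top-nk group score used above (the third group_score overload), with made-up numbers; it needs nothing beyond the standard library.

#include <algorithm>
#include <cstdio>
#include <functional>

int main() {
    // one group of 4 expert probabilities, scored by the sum of its top-2
    const float group[4] = {0.10f, 0.40f, 0.05f, 0.30f};
    const int nk = 2;
    float aux[4];
    std::copy(group, group + 4, aux);
    // move the nk largest values to the front, exactly as group_score() does
    std::partial_sort(aux, aux + nk, aux + 4, std::greater<float>{});
    float score = 0;
    for (int j = 0; j < nk; ++j) score += aux[j];
    printf("group score = %.2f\n", score); // prints 0.70 (= 0.40 + 0.30)
    return 0;
}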


@@ -420,6 +420,7 @@ extern "C" {
         int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
         int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
         bool fused_moe_up_gate; // whether to use fused MoE up/gate op
+        bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch)
         bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
         int min_experts;
         float thresh_experts;


@@ -48,6 +48,7 @@ llm_build_context::llm_build_context(
     mla_attn (cparams.mla_attn),
     attn_max_batch (cparams.attn_max_batch),
     fused_moe_up_gate(cparams.fused_moe_up_gate),
+    grouped_expert_routing(cparams.grouped_expert_routing),
     fused_up_gate (cparams.fused_up_gate),
     min_experts (cparams.min_experts),
     thresh_experts (cparams.thresh_experts),
@@ -820,42 +821,15 @@ llm_expert_gating_func_type gating_op,
         selection_probs = logits;
     }

-    if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
-        auto& hparams = lctx.model.hparams;
-        const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
-        // organize experts into n_expert_groups
-        ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
-#if 0
-        ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
-        group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
-        // get top n_group_used expert groups
-        group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
-#else
-        // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
-        ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
-        group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
-        // get top n_group_used expert groups
-        group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
-#endif
-        ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1]
-        cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
-        cb(expert_groups, "ffn_moe_group_topk", il);
-
-        // mask out the other groups
-        selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
-        selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
-        selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
-        selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens]
-        cb(selection_probs, "ffn_moe_probs_masked", il);
-    }
-
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
-            lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    ggml_tensor * selected_experts;
+    if (lctx.cparams.grouped_expert_routing && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
+        auto& hparams = lctx.model.hparams;
+        selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used);
+    } else {
+        selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
+                lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
+    }
     cb(selected_experts, "ffn_moe_topk", il);

     ggml_tensor * weights = ggml_get_rows(ctx,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
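Two details worth noting at this call site: nk is hard-coded to 2, i.e. groups are scored by the sum of their top-2 probabilities (the scheme the removed #if 0 graph code expressed with ggml_top_k(..., 2), and which its #else branch approximated with argmax due to backend limitations), and the whole mask-groups-then-top-k dance collapses into a single ggml_grouped_topk node. Both branches hand the same [n_expert_used, n_tokens] index tensor to the ggml_get_rows below, so nothing downstream changes.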


@@ -78,6 +78,7 @@ struct llm_build_context {
     const int mla_attn;
     const int attn_max_batch;
     const bool fused_moe_up_gate;
+    const bool grouped_expert_routing;
     const bool fused_up_gate;
     const int min_experts;
     const float thresh_experts;


@@ -31,6 +31,7 @@ struct llama_cparams {
     int mla_attn;
     int attn_max_batch;
     bool fused_moe_up_gate;
+    bool grouped_expert_routing;
     bool fused_up_gate;
     int min_experts;
     float thresh_experts;


@@ -3754,6 +3754,7 @@ struct llama_context_params llama_context_default_params() {
         /*.mla_attn =*/ 0,
         /*.attn_max_batch =*/ 0,
         /*.fused_moe_up_gate =*/ false,
+        /*.grouped_expert_routing =*/ false,
         /*.fused_up_gate =*/ true,
         /*.min_experts =*/ -1,
         /*.thtesh_experts =*/ 0.0f,
@@ -3963,6 +3964,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mla_attn = params.mla_attn;
     cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate= params.fused_moe_up_gate;
+    cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
@@ -4043,6 +4045,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
     LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
     LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
+    LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);