mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Grouped expert routing (CPU only) (#836)
* Better argsort (CPU) * Attemt at grouped topk * This seems to do the trick for grouped experts routing * Cleanup * Trying to merge, something is not right * Working merged grouped top_k (CPU) * Add command line option to enable grouped expert routing * Add grouped expert routing option to llama-bench --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.fused_moe_up_gate = true;
|
params.fused_moe_up_gate = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-ger" || arg == "--grouped-expert-routing") {
|
||||||
|
params.grouped_expert_routing = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
|
if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
|
||||||
params.fused_up_gate = false;
|
params.fused_up_gate = false;
|
||||||
return true;
|
return true;
|
||||||
@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
|
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
|
||||||
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
|
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
|
||||||
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
|
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
|
||||||
|
options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
|
||||||
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
|
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
|
||||||
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
|
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
|
||||||
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
|
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
|
||||||
@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||||||
cparams.mla_attn = params.mla_attn;
|
cparams.mla_attn = params.mla_attn;
|
||||||
cparams.attn_max_batch = params.attn_max_batch;
|
cparams.attn_max_batch = params.attn_max_batch;
|
||||||
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
|
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
|
||||||
|
cparams.grouped_expert_routing = params.grouped_expert_routing;
|
||||||
cparams.fused_up_gate = params.fused_up_gate;
|
cparams.fused_up_gate = params.fused_up_gate;
|
||||||
cparams.min_experts = params.min_experts;
|
cparams.min_experts = params.min_experts;
|
||||||
cparams.thresh_experts = params.thresh_experts;
|
cparams.thresh_experts = params.thresh_experts;
|
||||||
@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
|
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
|
||||||
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
|
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
|
||||||
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
|
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
|
||||||
|
fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
|
||||||
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
|
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
|
||||||
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
|
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
|
||||||
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ struct gpt_params {
|
|||||||
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
|
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
|
||||||
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
|
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
|
||||||
bool fused_up_gate = true; // fused up*unary(gate) op
|
bool fused_up_gate = true; // fused up*unary(gate) op
|
||||||
|
bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
|
||||||
int min_experts = -1;
|
int min_experts = -1;
|
||||||
float thresh_experts = 0;
|
float thresh_experts = 0;
|
||||||
|
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ struct cmd_params {
|
|||||||
bool warmup;
|
bool warmup;
|
||||||
bool repack = false;
|
bool repack = false;
|
||||||
bool fmoe = false;
|
bool fmoe = false;
|
||||||
|
bool ger = false; // ger = Grouped Expert Routing
|
||||||
bool no_fug = false;
|
bool no_fug = false;
|
||||||
bool use_thp = false;
|
bool use_thp = false;
|
||||||
output_formats output_format;
|
output_formats output_format;
|
||||||
@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
|
|||||||
/* verbose */ false,
|
/* verbose */ false,
|
||||||
/* warmup */ true,
|
/* warmup */ true,
|
||||||
/* repack */ false,
|
/* repack */ false,
|
||||||
/* use_thp */ false,
|
|
||||||
/* fmoe */ false,
|
/* fmoe */ false,
|
||||||
|
/* ger */ false,
|
||||||
/* no_fug */ false,
|
/* no_fug */ false,
|
||||||
|
/* use_thp */ false,
|
||||||
/* output_format */ MARKDOWN,
|
/* output_format */ MARKDOWN,
|
||||||
/* output_format_stderr */ NONE,
|
/* output_format_stderr */ NONE,
|
||||||
};
|
};
|
||||||
@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
|||||||
printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
|
printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
|
||||||
printf(" -ot, --override-tensor pattern (default: none)\n");
|
printf(" -ot, --override-tensor pattern (default: none)\n");
|
||||||
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
|
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
|
||||||
|
printf(" -ger, --grouped-expert-routing <0|1>(default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
|
||||||
printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
|
printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
|
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
|
||||||
@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.fmoe = std::stoi(argv[i]);
|
params.fmoe = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "-ger" || arg == "--grouped-expert-routing") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.ger = std::stoi(argv[i]);
|
||||||
} else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
|
} else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@@ -829,6 +838,7 @@ struct cmd_params_instance {
|
|||||||
bool embeddings;
|
bool embeddings;
|
||||||
bool repack = false;
|
bool repack = false;
|
||||||
bool fmoe = false;
|
bool fmoe = false;
|
||||||
|
bool ger = false;
|
||||||
bool no_fug = false;
|
bool no_fug = false;
|
||||||
bool use_thp = false;
|
bool use_thp = false;
|
||||||
const llama_model_tensor_buft_override* buft_overrides;
|
const llama_model_tensor_buft_override* buft_overrides;
|
||||||
@@ -876,6 +886,7 @@ struct cmd_params_instance {
|
|||||||
cparams.mla_attn = mla_attn;
|
cparams.mla_attn = mla_attn;
|
||||||
cparams.attn_max_batch = attn_max_batch;
|
cparams.attn_max_batch = attn_max_batch;
|
||||||
cparams.fused_moe_up_gate = fmoe;
|
cparams.fused_moe_up_gate = fmoe;
|
||||||
|
cparams.grouped_expert_routing = ger;
|
||||||
cparams.fused_up_gate = !no_fug;
|
cparams.fused_up_gate = !no_fug;
|
||||||
cparams.min_experts = ser.first;
|
cparams.min_experts = ser.first;
|
||||||
cparams.thresh_experts = ser.second;
|
cparams.thresh_experts = ser.second;
|
||||||
@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
|||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
/* .repack = */ params.repack,
|
/* .repack = */ params.repack,
|
||||||
/* .fmoe = */ params.fmoe,
|
/* .fmoe = */ params.fmoe,
|
||||||
|
/* .ger = */ params.ger,
|
||||||
/* .no_fug = */ params.no_fug,
|
/* .no_fug = */ params.no_fug,
|
||||||
/* .use_thp = */ params.use_thp,
|
/* .use_thp = */ params.use_thp,
|
||||||
/* .buft_overrides=*/ params.buft_overrides.data(),
|
/* .buft_overrides=*/ params.buft_overrides.data(),
|
||||||
@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
|||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
/* .repack = */ params.repack,
|
/* .repack = */ params.repack,
|
||||||
/* .fmoe = */ params.fmoe,
|
/* .fmoe = */ params.fmoe,
|
||||||
|
/* .ger = */ params.ger,
|
||||||
/* .no_fug = */ params.no_fug,
|
/* .no_fug = */ params.no_fug,
|
||||||
/* .use_thp = */ params.use_thp,
|
/* .use_thp = */ params.use_thp,
|
||||||
/* .buft_overrides=*/ params.buft_overrides.data(),
|
/* .buft_overrides=*/ params.buft_overrides.data(),
|
||||||
@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
|||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
/* .repack = */ params.repack,
|
/* .repack = */ params.repack,
|
||||||
/* .fmoe = */ params.fmoe,
|
/* .fmoe = */ params.fmoe,
|
||||||
|
/* .ger = */ params.ger,
|
||||||
/* .no_fug = */ params.no_fug,
|
/* .no_fug = */ params.no_fug,
|
||||||
/* .use_thp = */ params.use_thp,
|
/* .use_thp = */ params.use_thp,
|
||||||
/* .buft_overrides=*/ params.buft_overrides.data(),
|
/* .buft_overrides=*/ params.buft_overrides.data(),
|
||||||
@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
|||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
/* .repack = */ params.repack,
|
/* .repack = */ params.repack,
|
||||||
/* .fmoe = */ params.fmoe,
|
/* .fmoe = */ params.fmoe,
|
||||||
|
/* .ger = */ params.ger,
|
||||||
/* .no_fug = */ params.no_fug,
|
/* .no_fug = */ params.no_fug,
|
||||||
/* .use_thp = */ params.use_thp,
|
/* .use_thp = */ params.use_thp,
|
||||||
/* .buft_overrides=*/ params.buft_overrides.data(),
|
/* .buft_overrides=*/ params.buft_overrides.data(),
|
||||||
@@ -1086,6 +1101,7 @@ struct test {
|
|||||||
bool embeddings;
|
bool embeddings;
|
||||||
bool repack = false;
|
bool repack = false;
|
||||||
bool fmoe = false;
|
bool fmoe = false;
|
||||||
|
bool ger = false;
|
||||||
bool no_fug = false;
|
bool no_fug = false;
|
||||||
bool use_thp = false;
|
bool use_thp = false;
|
||||||
int n_prompt;
|
int n_prompt;
|
||||||
@@ -1120,6 +1136,8 @@ struct test {
|
|||||||
use_mmap = inst.use_mmap;
|
use_mmap = inst.use_mmap;
|
||||||
embeddings = inst.embeddings;
|
embeddings = inst.embeddings;
|
||||||
repack = inst.repack;
|
repack = inst.repack;
|
||||||
|
fmoe = inst.fmoe;
|
||||||
|
ger = inst.ger;
|
||||||
no_fug = inst.no_fug;
|
no_fug = inst.no_fug;
|
||||||
use_thp = inst.use_thp;
|
use_thp = inst.use_thp;
|
||||||
n_prompt = inst.n_prompt;
|
n_prompt = inst.n_prompt;
|
||||||
@@ -1212,7 +1230,7 @@ struct test {
|
|||||||
"n_threads", "type_k", "type_v",
|
"n_threads", "type_k", "type_v",
|
||||||
"n_gpu_layers", "split_mode",
|
"n_gpu_layers", "split_mode",
|
||||||
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
|
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
|
||||||
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp",
|
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
|
||||||
"n_prompt", "n_gen", "test_time",
|
"n_prompt", "n_gen", "test_time",
|
||||||
"avg_ns", "stddev_ns",
|
"avg_ns", "stddev_ns",
|
||||||
"avg_ts", "stddev_ts", "test",
|
"avg_ts", "stddev_ts", "test",
|
||||||
@@ -1234,7 +1252,7 @@ struct test {
|
|||||||
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
||||||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
||||||
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
|
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
|
||||||
field == "fused_moe" || field == "fused_up_gate") {
|
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
|
||||||
return BOOL;
|
return BOOL;
|
||||||
}
|
}
|
||||||
if (field == "avg_ts" || field == "stddev_ts") {
|
if (field == "avg_ts" || field == "stddev_ts") {
|
||||||
@@ -1277,7 +1295,8 @@ struct test {
|
|||||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
|
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
|
||||||
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
|
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
|
||||||
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
||||||
std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp),
|
std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
|
||||||
|
std::to_string(no_fug), std::to_string(use_thp),
|
||||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||||
std::to_string(avg_ts()), std::to_string(stdev_ts()),
|
std::to_string(avg_ts()), std::to_string(stdev_ts()),
|
||||||
@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
|
|||||||
if (field == "fused_moe") {
|
if (field == "fused_moe") {
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
|
if (field == "grouped_er") {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
if (field == "fused_up_gate") {
|
if (field == "fused_up_gate") {
|
||||||
return 6;
|
return 6;
|
||||||
}
|
}
|
||||||
@@ -1513,6 +1535,12 @@ struct markdown_printer : public printer {
|
|||||||
if (field == "fused_moe") {
|
if (field == "fused_moe") {
|
||||||
return "fmoe";
|
return "fmoe";
|
||||||
}
|
}
|
||||||
|
if (field == "grouped_er") {
|
||||||
|
return "ger";
|
||||||
|
}
|
||||||
|
if (field == "grouped_er") {
|
||||||
|
return "ger";
|
||||||
|
}
|
||||||
if (field == "fused_up_gate") {
|
if (field == "fused_up_gate") {
|
||||||
return "no-fug";
|
return "no-fug";
|
||||||
}
|
}
|
||||||
@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
|
|||||||
if (params.fmoe != cmd_params_defaults.fmoe) {
|
if (params.fmoe != cmd_params_defaults.fmoe) {
|
||||||
fields.emplace_back("fused_moe");
|
fields.emplace_back("fused_moe");
|
||||||
}
|
}
|
||||||
|
if (params.ger != cmd_params_defaults.ger) {
|
||||||
|
fields.emplace_back("grouped_er");
|
||||||
|
}
|
||||||
if (params.no_fug != cmd_params_defaults.no_fug) {
|
if (params.no_fug != cmd_params_defaults.no_fug) {
|
||||||
fields.emplace_back("fused_up_gate");
|
fields.emplace_back("fused_up_gate");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -650,6 +650,7 @@ extern "C" {
|
|||||||
GGML_OP_TIMESTEP_EMBEDDING,
|
GGML_OP_TIMESTEP_EMBEDDING,
|
||||||
GGML_OP_ARGSORT,
|
GGML_OP_ARGSORT,
|
||||||
GGML_OP_ARGSORT_THRESH,
|
GGML_OP_ARGSORT_THRESH,
|
||||||
|
GGML_OP_GROUPED_TOPK,
|
||||||
GGML_OP_LEAKY_RELU,
|
GGML_OP_LEAKY_RELU,
|
||||||
GGML_OP_SOFTCAP,
|
GGML_OP_SOFTCAP,
|
||||||
GGML_OP_SOFT_CAP_MAX,
|
GGML_OP_SOFT_CAP_MAX,
|
||||||
@@ -2265,6 +2266,13 @@ extern "C" {
|
|||||||
int k,
|
int k,
|
||||||
int min_entries,
|
int min_entries,
|
||||||
float thresh);
|
float thresh);
|
||||||
|
GGML_API struct ggml_tensor * ggml_grouped_topk(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int num_groups,
|
||||||
|
int num_top_groups,
|
||||||
|
int nk,
|
||||||
|
int topk_experts);
|
||||||
|
|
||||||
#define GGML_KQ_MASK_PAD 16
|
#define GGML_KQ_MASK_PAD 16
|
||||||
|
|
||||||
|
|||||||
@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"TIMESTEP_EMBEDDING",
|
"TIMESTEP_EMBEDDING",
|
||||||
"ARGSORT",
|
"ARGSORT",
|
||||||
"ARGSORT_THRESH",
|
"ARGSORT_THRESH",
|
||||||
|
"GROUPED_TOPK",
|
||||||
"LEAKY_RELU",
|
"LEAKY_RELU",
|
||||||
"SOFTCAP",
|
"SOFTCAP",
|
||||||
"SOFT_CAP_MAX",
|
"SOFT_CAP_MAX",
|
||||||
@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"GLU",
|
"GLU",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"timestep_embedding(timesteps, dim, max_period)",
|
"timestep_embedding(timesteps, dim, max_period)",
|
||||||
"argsort(x)",
|
"argsort(x)",
|
||||||
"argsort_thresh(x)",
|
"argsort_thresh(x)",
|
||||||
|
"grouped_topk(x)",
|
||||||
"leaky_relu(x)",
|
"leaky_relu(x)",
|
||||||
"k2*tanh(k1*x)",
|
"k2*tanh(k1*x)",
|
||||||
"soft_max(k2*tanh(k1*x))",
|
"soft_max(k2*tanh(k1*x))",
|
||||||
@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"glu(x),"
|
"glu(x),"
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
@@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_grouped_topk(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int num_groups,
|
||||||
|
int num_top_groups,
|
||||||
|
int nk,
|
||||||
|
int topk_experts) {
|
||||||
|
|
||||||
|
GGML_ASSERT(num_top_groups <= num_groups);
|
||||||
|
GGML_ASSERT(a->ne[0] % num_groups == 0);
|
||||||
|
GGML_ASSERT(a->ne[0] >= topk_experts);
|
||||||
|
int64_t n_per_group = a->ne[0] / num_groups;
|
||||||
|
GGML_ASSERT(n_per_group >= nk);
|
||||||
|
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
int64_t ne[GGML_MAX_DIMS];
|
||||||
|
for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i];
|
||||||
|
ne[0] = topk_experts;
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne);
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(result, 0, num_groups);
|
||||||
|
ggml_set_op_params_i32(result, 1, num_top_groups);
|
||||||
|
ggml_set_op_params_i32(result, 2, nk);
|
||||||
|
|
||||||
|
result->op = GGML_OP_GROUPED_TOPK;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = a;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// ggml_top_k
|
// ggml_top_k
|
||||||
|
|
||||||
struct ggml_tensor * ggml_top_k(
|
struct ggml_tensor * ggml_top_k(
|
||||||
@@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_grouped_topk(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
iqk_grouped_top_k(dst, params->ith, params->nth);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_flash_attn_ext
|
// ggml_compute_forward_flash_attn_ext
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
@@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_argsort_thresh(params, tensor);
|
ggml_compute_forward_argsort_thresh(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_GROUPED_TOPK:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_grouped_topk(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_leaky_relu(params, tensor);
|
ggml_compute_forward_leaky_relu(params, tensor);
|
||||||
@@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
{
|
{
|
||||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||||
}
|
}
|
||||||
|
case GGML_OP_GROUPED_TOPK:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||||
|
}
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
{
|
{
|
||||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||||
@@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
case GGML_OP_ARGSORT_THRESH:
|
case GGML_OP_ARGSORT_THRESH:
|
||||||
|
case GGML_OP_GROUPED_TOPK:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
|
|||||||
@@ -10,8 +10,97 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
void iqk_grouped_top_k([[maybe_unused]] ggml_tensor * dst, [[maybe_unused]] int ith, [[maybe_unused]] int nth) {
|
namespace {
|
||||||
|
// Playing around with group scores: use sum of probabilities in the group
|
||||||
|
inline float group_score(int n_per_group, const float * data) {
|
||||||
|
float sum = 0;
|
||||||
|
for (int j = 0; j < n_per_group; ++j) sum += data[j];
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
// Playing around with group scores: use max of probabilities in the group
|
||||||
|
inline float group_score_max(int n_per_group, const float * data) {
|
||||||
|
float max = data[0];
|
||||||
|
for (int j = 1; j < n_per_group; ++j) max = std::max(max, data[j]);
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
// Actual top-nk group score: sum of top-nk probabilities in the group
|
||||||
|
inline float group_score(int n_per_group, int nk, const float * data, float * aux) {
|
||||||
|
for (int j = 0; j < n_per_group; ++j) aux[j] = data[j];
|
||||||
|
std::partial_sort(aux, aux + nk, aux + n_per_group, std::greater<float>{});
|
||||||
|
float sum = 0;
|
||||||
|
for (int j = 0; j < nk; ++j) sum += aux[j];
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
inline std::vector<std::pair<float,int>> & get_work_buffer(size_t size) {
|
||||||
|
thread_local std::vector<std::pair<float,int>> buffer;
|
||||||
|
if (buffer.size() < size) buffer.resize(size);
|
||||||
|
return buffer;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) {
|
||||||
|
auto src = dst->src[0];
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||||
|
GGML_ASSERT(src->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
|
||||||
|
|
||||||
|
auto nrows = ggml_nrows(src);
|
||||||
|
auto npt = (nrows + nth - 1)/nth;
|
||||||
|
auto first = npt*ith;
|
||||||
|
auto last = std::min(first + npt, nrows);
|
||||||
|
if (last <= first) return;
|
||||||
|
|
||||||
|
int n_groups = dst->op_params[0];
|
||||||
|
int n_top_groups = dst->op_params[1];
|
||||||
|
int nk = dst->op_params[2];
|
||||||
|
|
||||||
|
int ne00 = src->ne[0];
|
||||||
|
int ne0 = dst->ne[0];
|
||||||
|
GGML_ASSERT(ne0 <= ne00);
|
||||||
|
GGML_ASSERT(ne00%n_groups == 0);
|
||||||
|
int n_per_group = ne00/n_groups;
|
||||||
|
GGML_ASSERT(nk <= n_per_group);
|
||||||
|
GGML_ASSERT(n_top_groups <= n_groups);
|
||||||
|
|
||||||
|
size_t work_size = n_groups + n_per_group*n_top_groups;
|
||||||
|
auto& aux = get_work_buffer(work_size);
|
||||||
|
|
||||||
|
auto groups = aux.data() + n_per_group*n_top_groups;
|
||||||
|
|
||||||
|
for (int ir = first; ir < last; ++ir) {
|
||||||
|
auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
|
||||||
|
auto result = (int32_t *)((char *)dst->data + ir*dst->nb[1]);
|
||||||
|
if (ne0 > n_per_group*n_top_groups) {
|
||||||
|
for (int j = 0; j < ne0; ++j) result[j] = j;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (n_top_groups < n_groups) {
|
||||||
|
for (int ig = 0; ig < n_groups; ++ig) {
|
||||||
|
//groups[ig] = { group_score(n_per_group, data + ig*n_per_group), ig };
|
||||||
|
//groups[ig] = { group_score_max(n_per_group, data + ig*n_per_group), ig };
|
||||||
|
groups[ig] = { group_score(n_per_group, nk, data + ig*n_per_group, (float *)aux.data()), ig };
|
||||||
|
}
|
||||||
|
std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater<std::pair<float,int>>{});
|
||||||
|
|
||||||
|
for (int ig = 0; ig < n_top_groups; ++ig) {
|
||||||
|
int i0 = n_per_group * ig;
|
||||||
|
int j0 = n_per_group * groups[ig].second;
|
||||||
|
for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { data[j0 + j], j0 + j };
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j };
|
||||||
|
}
|
||||||
|
if (ne0 < n_top_groups*n_per_group) {
|
||||||
|
std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater<std::pair<float,int>>{});
|
||||||
|
} else {
|
||||||
|
std::sort(aux.begin(), aux.begin() + ne0, std::greater<std::pair<float,int>>{});
|
||||||
|
}
|
||||||
|
for (int j = 0; j < ne0; ++j) result[j] = aux[j].second;
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
|
void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
|
||||||
@@ -30,8 +119,7 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) {
|
|||||||
int nk = dst->op_params[1];
|
int nk = dst->op_params[1];
|
||||||
|
|
||||||
int ne00 = src->ne[0];
|
int ne00 = src->ne[0];
|
||||||
thread_local std::vector<std::pair<float,int>> aux;
|
auto& aux = get_work_buffer(ne00);
|
||||||
if ((int)aux.size() < ne00) aux.resize(ne00);
|
|
||||||
|
|
||||||
for (int ir = first; ir < last; ++ir) {
|
for (int ir = first; ir < last; ++ir) {
|
||||||
auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
|
auto data = (const float *)((const char *)src->data + ir*src->nb[1]);
|
||||||
|
|||||||
@@ -420,6 +420,7 @@ extern "C" {
|
|||||||
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
|
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
|
||||||
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
|
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
|
||||||
bool fused_moe_up_gate; // whether to use fused MoE up/gate op
|
bool fused_moe_up_gate; // whether to use fused MoE up/gate op
|
||||||
|
bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch)
|
||||||
bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
|
bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
|
||||||
int min_experts;
|
int min_experts;
|
||||||
float thresh_experts;
|
float thresh_experts;
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ llm_build_context::llm_build_context(
|
|||||||
mla_attn (cparams.mla_attn),
|
mla_attn (cparams.mla_attn),
|
||||||
attn_max_batch (cparams.attn_max_batch),
|
attn_max_batch (cparams.attn_max_batch),
|
||||||
fused_moe_up_gate(cparams.fused_moe_up_gate),
|
fused_moe_up_gate(cparams.fused_moe_up_gate),
|
||||||
|
grouped_expert_routing(cparams.grouped_expert_routing),
|
||||||
fused_up_gate (cparams.fused_up_gate),
|
fused_up_gate (cparams.fused_up_gate),
|
||||||
min_experts (cparams.min_experts),
|
min_experts (cparams.min_experts),
|
||||||
thresh_experts (cparams.thresh_experts),
|
thresh_experts (cparams.thresh_experts),
|
||||||
@@ -820,42 +821,15 @@ llm_expert_gating_func_type gating_op,
|
|||||||
selection_probs = logits;
|
selection_probs = logits;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
|
|
||||||
auto& hparams = lctx.model.hparams;
|
|
||||||
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
|
|
||||||
|
|
||||||
// organize experts into n_expert_groups
|
|
||||||
ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
|
|
||||||
#if 0
|
|
||||||
ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
|
|
||||||
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
|
|
||||||
|
|
||||||
// get top n_group_used expert groups
|
|
||||||
group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
|
|
||||||
#else
|
|
||||||
// Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
|
|
||||||
ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
|
|
||||||
group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
|
|
||||||
|
|
||||||
// get top n_group_used expert groups
|
|
||||||
group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
|
|
||||||
#endif
|
|
||||||
ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1]
|
|
||||||
cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
|
|
||||||
cb(expert_groups, "ffn_moe_group_topk", il);
|
|
||||||
|
|
||||||
// mask out the other groups
|
|
||||||
selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
|
|
||||||
selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
|
|
||||||
selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
|
|
||||||
selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens]
|
|
||||||
cb(selection_probs, "ffn_moe_probs_masked", il);
|
|
||||||
}
|
|
||||||
|
|
||||||
// select experts
|
// select experts
|
||||||
ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
|
ggml_tensor * selected_experts;
|
||||||
lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
|
if (lctx.cparams.grouped_expert_routing && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
|
||||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
auto& hparams = lctx.model.hparams;
|
||||||
|
selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used);
|
||||||
|
} else {
|
||||||
|
selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
|
||||||
|
lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
|
||||||
|
}
|
||||||
cb(selected_experts, "ffn_moe_topk", il);
|
cb(selected_experts, "ffn_moe_topk", il);
|
||||||
ggml_tensor * weights = ggml_get_rows(ctx,
|
ggml_tensor * weights = ggml_get_rows(ctx,
|
||||||
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ struct llm_build_context {
|
|||||||
const int mla_attn;
|
const int mla_attn;
|
||||||
const int attn_max_batch;
|
const int attn_max_batch;
|
||||||
const bool fused_moe_up_gate;
|
const bool fused_moe_up_gate;
|
||||||
|
const bool grouped_expert_routing;
|
||||||
const bool fused_up_gate;
|
const bool fused_up_gate;
|
||||||
const int min_experts;
|
const int min_experts;
|
||||||
const float thresh_experts;
|
const float thresh_experts;
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ struct llama_cparams {
|
|||||||
int mla_attn;
|
int mla_attn;
|
||||||
int attn_max_batch;
|
int attn_max_batch;
|
||||||
bool fused_moe_up_gate;
|
bool fused_moe_up_gate;
|
||||||
|
bool grouped_expert_routing;
|
||||||
bool fused_up_gate;
|
bool fused_up_gate;
|
||||||
int min_experts;
|
int min_experts;
|
||||||
float thresh_experts;
|
float thresh_experts;
|
||||||
|
|||||||
@@ -3754,6 +3754,7 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.mla_attn =*/ 0,
|
/*.mla_attn =*/ 0,
|
||||||
/*.attn_max_batch =*/ 0,
|
/*.attn_max_batch =*/ 0,
|
||||||
/*.fused_moe_up_gate =*/ false,
|
/*.fused_moe_up_gate =*/ false,
|
||||||
|
/*.grouped_expert_routing =*/ false,
|
||||||
/*.fused_up_gate =*/ true,
|
/*.fused_up_gate =*/ true,
|
||||||
/*.min_experts =*/ -1,
|
/*.min_experts =*/ -1,
|
||||||
/*.thtesh_experts =*/ 0.0f,
|
/*.thtesh_experts =*/ 0.0f,
|
||||||
@@ -3963,6 +3964,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
cparams.mla_attn = params.mla_attn;
|
cparams.mla_attn = params.mla_attn;
|
||||||
cparams.attn_max_batch = params.attn_max_batch;
|
cparams.attn_max_batch = params.attn_max_batch;
|
||||||
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
|
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
|
||||||
|
cparams.grouped_expert_routing = params.grouped_expert_routing;
|
||||||
cparams.fused_up_gate = params.fused_up_gate;
|
cparams.fused_up_gate = params.fused_up_gate;
|
||||||
cparams.min_experts = params.min_experts;
|
cparams.min_experts = params.min_experts;
|
||||||
cparams.thresh_experts = params.thresh_experts;
|
cparams.thresh_experts = params.thresh_experts;
|
||||||
@@ -4043,6 +4045,7 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
|
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
|
||||||
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
|
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
|
||||||
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
|
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
|
||||||
|
LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
|
||||||
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
|
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
|
||||||
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
|
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
|
||||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||||
|
|||||||
Reference in New Issue
Block a user