mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 14:44:09 +00:00
Add command line arg to disable rope cache
This commit is contained in:
@@ -1106,6 +1106,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.fused_mmad = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-no-rcache" || arg == "--no-rope-cache") {
|
||||
params.rope_cache = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "-ser" || arg == "--smart-expert-reduction") {
|
||||
CHECK_ARG
|
||||
auto values = string_split_pairs<int,float>(argv[i], ',');
|
||||
@@ -1914,6 +1918,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||
options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-no-rcache, --no-rope-cache", "disaable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" });
|
||||
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
|
||||
options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
|
||||
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
|
||||
@@ -2887,6 +2892,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||
cparams.grouped_expert_routing = params.grouped_expert_routing;
|
||||
cparams.fused_up_gate = params.fused_up_gate;
|
||||
cparams.fused_mmad = params.fused_mmad;
|
||||
cparams.rope_cache = params.rope_cache;
|
||||
cparams.min_experts = params.min_experts;
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
cparams.only_active_experts = params.only_active_exps;
|
||||
@@ -4005,7 +4011,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
|
||||
fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
|
||||
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
|
||||
fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad? "true" : "false");
|
||||
fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
|
||||
fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false");
|
||||
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
|
||||
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ enum common_reasoning_format {
|
||||
enum common_webui {
|
||||
COMMON_WEBUI_NONE,
|
||||
COMMON_WEBUI_AUTO,
|
||||
COMMON_WEBUI_LLAMACPP,
|
||||
COMMON_WEBUI_LLAMACPP,
|
||||
};
|
||||
|
||||
common_webui common_webui_from_name(const std::string& format);
|
||||
@@ -249,6 +249,7 @@ struct gpt_params {
|
||||
bool fused_up_gate = true; // fused up*unary(gate) op
|
||||
bool fused_mmad = true; // fused mul+multi_add op
|
||||
bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
|
||||
bool rope_cache = true; // if to use RoPE cache (for supported models)
|
||||
int min_experts = -1;
|
||||
float thresh_experts = 0;
|
||||
|
||||
|
||||
@@ -3097,7 +3097,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
|
||||
cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
|
||||
ops_are_same_device(cgraph, i, i+2)) {
|
||||
printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
|
||||
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
}
|
||||
|
||||
@@ -427,6 +427,7 @@ extern "C" {
|
||||
bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch)
|
||||
bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
|
||||
bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL]
|
||||
bool rope_cache; // whether to use RoPE cache [EXPERIMENTAL]
|
||||
int min_experts;
|
||||
float thresh_experts;
|
||||
bool only_active_experts;
|
||||
|
||||
@@ -51,6 +51,7 @@ llm_build_context::llm_build_context(
|
||||
grouped_expert_routing(cparams.grouped_expert_routing),
|
||||
fused_up_gate (cparams.fused_up_gate),
|
||||
fused_mmad (cparams.fused_mmad),
|
||||
rope_cache (cparams.rope_cache),
|
||||
min_experts (cparams.min_experts),
|
||||
thresh_experts (cparams.thresh_experts),
|
||||
pooling_type (cparams.pooling_type),
|
||||
@@ -3372,8 +3373,9 @@ ggml_cgraph * llm_build_context::build_qwen3() {
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
auto rope_cache = cparams.rope_cache ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
@@ -3391,21 +3393,18 @@ ggml_cgraph * llm_build_context::build_qwen3() {
|
||||
model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
|
||||
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
//cb(Qcur, "Qcur", il);
|
||||
|
||||
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
//cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
@@ -3476,9 +3475,9 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
//ggml_set_input(rope_cache);
|
||||
auto rope_cache = cparams.rope_cache ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
@@ -3494,20 +3493,16 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
|
||||
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
//Qcur = ggml_rope_ext(
|
||||
// ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow
|
||||
// );
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
//Kcur = ggml_rope_ext(
|
||||
// ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow
|
||||
// );
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
@@ -6096,8 +6091,9 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
// output token IDs (for last layer cropping)
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
auto rope_cache = cparams.rope_cache ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
// Only process up to last layer (skip final NextN layer)
|
||||
// Final layer tensors are loaded but not processed in forward pass
|
||||
@@ -6119,14 +6115,15 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
|
||||
|
||||
// apply RoPE
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
@@ -7807,8 +7804,9 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
|
||||
const int sliding_window_pattern = 2;
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
auto rope_cache = cparams.rope_cache ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
@@ -7829,33 +7827,18 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
model.layers[il].wv, model.layers[il].bv,
|
||||
nullptr, nullptr, 0.0f, il);
|
||||
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
|
||||
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
||||
// beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
||||
// attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
//auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
|
||||
// model.layers[il].wk, model.layers[il].bk,
|
||||
// model.layers[il].wv, model.layers[il].bv, 0.f, il);
|
||||
|
||||
//Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
||||
// beta_fast, beta_slow);
|
||||
//cb(Qcur, "Qcur", il);
|
||||
|
||||
//Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
||||
// attn_factor, beta_fast, beta_slow);
|
||||
//cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
|
||||
is_sliding ? hparams.n_swa : 0);
|
||||
@@ -7943,8 +7926,9 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
|
||||
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
auto rope_cache = cparams.rope_cache ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@@ -7959,30 +7943,15 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
|
||||
|
||||
//cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
||||
//cb(cur, "wqkv", il);
|
||||
|
||||
//ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
//ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
////ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
//ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
//Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
//cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
|
||||
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
////Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
////cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
@@ -81,6 +81,7 @@ struct llm_build_context {
|
||||
const bool grouped_expert_routing;
|
||||
const bool fused_up_gate;
|
||||
const bool fused_mmad;
|
||||
const bool rope_cache;
|
||||
const int min_experts;
|
||||
const float thresh_experts;
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ struct llama_cparams {
|
||||
bool grouped_expert_routing;
|
||||
bool fused_up_gate;
|
||||
bool fused_mmad;
|
||||
bool rope_cache;
|
||||
int min_experts;
|
||||
float thresh_experts;
|
||||
|
||||
|
||||
@@ -3833,6 +3833,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.grouped_expert_routing =*/ false,
|
||||
/*.fused_up_gate =*/ true,
|
||||
/*.fused_mmad =*/ true,
|
||||
/*.rope_cache =*/ true,
|
||||
/*.min_experts =*/ -1,
|
||||
/*.thtesh_experts =*/ 0.0f,
|
||||
/*.only_active_experts =*/ false,
|
||||
@@ -4134,6 +4135,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
cparams.grouped_expert_routing = params.grouped_expert_routing;
|
||||
cparams.fused_up_gate = params.fused_up_gate;
|
||||
cparams.fused_mmad = params.fused_mmad;
|
||||
cparams.rope_cache = params.rope_cache;
|
||||
cparams.min_experts = params.min_experts;
|
||||
cparams.thresh_experts = params.thresh_experts;
|
||||
|
||||
@@ -4216,6 +4218,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
|
||||
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
|
||||
LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
|
||||
LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
|
||||
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
Reference in New Issue
Block a user