RoPE cache (#887)

* Introducing rope cache

When computing RoPE, the rotation angles in each layer
are exactly the same, and only depend on the token positions
(and other constant, model dependent parameters).
So, I wonder, why don't we compute the angles just once
and then reuse for the Q and K RoPE in each layer?

This commit does it as a POC on the CPU, and uses it in
the Qwen3-MoE compute graph.

* cuda: neox works

* WIP

* rope_cache: norm works

* Fused rope+rope

* Fused rope+rope (norm)

* Fused rms+rms+rope+rope (neox) - not working

* WIP

* Also qwen3

* Add command line arg to disable rope cache

* Disable RoPE cache if rope type is not neox or norm

* Add missing break after merge with main

* Fused fused_rms+fused_rms+rope+rope (with -mqkv)

* Fused fused_rms+fused_rms+rope+rope (without -mqkv)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-03 18:42:20 +02:00
committed by GitHub
parent 846e736e85
commit fb0d5a995c
12 changed files with 1002 additions and 72 deletions

View File

@@ -3833,6 +3833,7 @@ struct llama_context_params llama_context_default_params() {
/*.grouped_expert_routing =*/ false,
/*.fused_up_gate =*/ true,
/*.fused_mmad =*/ true,
/*.rope_cache =*/ true,
/*.min_experts =*/ -1,
/*.thtesh_experts =*/ 0.0f,
/*.only_active_experts =*/ false,
@@ -4134,6 +4135,7 @@ struct llama_context * llama_new_context_with_model(
cparams.grouped_expert_routing = params.grouped_expert_routing;
cparams.fused_up_gate = params.fused_up_gate;
cparams.fused_mmad = params.fused_mmad;
cparams.rope_cache = params.rope_cache;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
@@ -4216,6 +4218,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);