mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 12:00:29 +00:00
RoPE cache (#887)
* Introducing rope cache When computing RoPE, the rotation angles in each layer are exactly the same, and only depend on the token positions (and other constant, model dependent parameters). So, I wonder, why don't we compute the angles just once and then reuse for the Q and K RoPE in each layer? This commit does it as a POC on the CPU, and uses it in the Qwen3-MoE compute graph. * cuda: neox works * WIP * rope_cache: norm works * Fused rope+rope * Fused rope+rope (norm) * Fused rms+rms+rope+rope (neox) - not working * WIP * Also qwen3 * Add command line arg to disable rope cache * Disable RoPE cache if rope type is not neox or norm * Add missing break after merge with main * Fused fused_rms+fused_rms+rope+rope (with -mqkv) * Fused fused_rms+fused_rms+rope+rope (without -mqkv) --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -51,6 +51,7 @@ llm_build_context::llm_build_context(
|
||||
grouped_expert_routing(cparams.grouped_expert_routing),
|
||||
fused_up_gate (cparams.fused_up_gate),
|
||||
fused_mmad (cparams.fused_mmad),
|
||||
rope_cache (cparams.rope_cache),
|
||||
min_experts (cparams.min_experts),
|
||||
thresh_experts (cparams.thresh_experts),
|
||||
pooling_type (cparams.pooling_type),
|
||||
@@ -3372,6 +3373,10 @@ ggml_cgraph * llm_build_context::build_qwen3() {
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
@@ -3388,14 +3393,16 @@ ggml_cgraph * llm_build_context::build_qwen3() {
|
||||
model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
@@ -3468,6 +3475,9 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
@@ -3483,18 +3493,16 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
@@ -6083,6 +6091,10 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
// output token IDs (for last layer cropping)
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
// Only process up to last layer (skip final NextN layer)
|
||||
// Final layer tensors are loaded but not processed in forward pass
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
@@ -6103,12 +6115,15 @@ ggml_cgraph * llm_build_context::build_glm4_moe() {
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
|
||||
|
||||
// apply RoPE
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
@@ -7769,6 +7784,9 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() {
|
||||
ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -7786,6 +7804,10 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
|
||||
const int sliding_window_pattern = 2;
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
|
||||
ggml_tensor * inpSA = inpL;
|
||||
@@ -7805,30 +7827,18 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
|
||||
model.layers[il].wv, model.layers[il].bv,
|
||||
nullptr, nullptr, 0.0f, il);
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
||||
beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
||||
attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
//auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
|
||||
// model.layers[il].wk, model.layers[il].bk,
|
||||
// model.layers[il].wv, model.layers[il].bv, 0.f, il);
|
||||
|
||||
//Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
|
||||
// beta_fast, beta_slow);
|
||||
//cb(Qcur, "Qcur", il);
|
||||
|
||||
//Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
|
||||
// attn_factor, beta_fast, beta_slow);
|
||||
//cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
|
||||
is_sliding ? hparams.n_swa : 0);
|
||||
@@ -7916,6 +7926,10 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
|
||||
|
||||
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
|
||||
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
|
||||
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
@@ -7929,27 +7943,15 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.0f, il);
|
||||
|
||||
//cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
|
||||
//cb(cur, "wqkv", il);
|
||||
|
||||
//ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
|
||||
//ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
|
||||
////ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
//ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
|
||||
|
||||
//Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
//cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
//Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
|
||||
//cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
if (rope_cache) {
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
Reference in New Issue
Block a user