Rename X_pe to X_rope

Much easier to follow, at least for my brain, when we have
  X_rope : rotational position encoding
  X_nope :         no position encoding
instead of X_pe and X_nope, where I was wondering wtf is 'pe'
and 'nope'.
This commit is contained in:
Iwan Kawrakow
2025-02-12 07:49:16 +02:00
parent 978aaa9f68
commit 54252d0256

View File

@@ -13419,30 +13419,30 @@ struct llm_build_context {
cb(q_nope, "q_nope", il);
// and {n_head * n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
struct ggml_tensor * q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
ggml_row_size(q->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);
cb(q_rope, "q_rope", il);
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_pe_compresseed, "kv_pe_compresseed", il);
struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_rope_compresseed, "kv_rope_compresseed", il);
// split into {kv_lora_rank, n_tokens}
struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
kv_pe_compresseed->nb[1],
struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens,
kv_rope_compresseed->nb[1],
0);
cb(kv_compressed, "kv_compressed", il);
if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
// and {n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_pe_compresseed->nb[1],
kv_pe_compresseed->nb[1],
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);
struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_rope_compresseed->nb[1],
kv_rope_compresseed->nb[1],
ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
cb(k_rope, "k_rope", il);
//kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
@@ -13476,28 +13476,26 @@ struct llm_build_context {
0);
cb(kv_cache_trans, "kv_cache_trans", il);
//q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
q_rope = ggml_rope_ext(
ctx0, q_rope, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);
cb(q_rope, "q_rope", il);
// shared RoPE key
//k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
k_rope = ggml_rope_ext(
ctx0, k_rope, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);
cb(k_rope, "k_rope", il);
struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
cb(kr_cache_view, "kr_cache_view", il);
// note: storing RoPE-ed version of K^R in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_rope, kr_cache_view));
struct ggml_tensor * kr_cache =
ggml_view_2d(ctx0, kv_self.kr_l[il],
@@ -13528,18 +13526,18 @@ struct llm_build_context {
}
if (pp_opt) {
q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
cb(q_pe, "q_pe_perm", il);
q_rope = ggml_permute(ctx0, q_rope, 0, 2, 1, 3);
cb(q_rope, "q_rope_perm", il);
}
struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
cb(kq_pe, "kq_pe", il);
struct ggml_tensor * kq_rope = ggml_mul_mat(ctx0, kr_cache, q_rope);
cb(kq_rope, "kq_rope", il);
if (!pp_opt) {
kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
cb(kq_pe, "kq_pe_perm", il);
kq_rope = ggml_permute(ctx0, kq_rope, 0, 2, 1, 3);
cb(kq_rope, "kq_rope_perm", il);
}
struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_rope);
cb(kq, "kq", il);
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
@@ -13579,11 +13577,11 @@ struct llm_build_context {
else {
// and {n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_pe_compresseed->nb[1],
kv_pe_compresseed->nb[1],
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);
struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_rope_compresseed->nb[1],
kv_rope_compresseed->nb[1],
ggml_row_size(kv_rope_compresseed->type, kv_lora_rank));
cb(k_rope, "k_pe", il);
//kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
@@ -13618,26 +13616,26 @@ struct llm_build_context {
cb(v_states, "v_states", il);
//q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
q_rope = ggml_rope_ext(
ctx0, q_rope, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);
cb(q_rope, "q_rope", il);
// shared RoPE key
//k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
k_rope = ggml_rope_ext(
ctx0, k_rope, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);
cb(k_rope, "k_rope", il);
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0);
cb(q_states, "q_states", il);
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0);
cb(k_states, "k_states", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,