Remove some unnecessary copies in the MLA attention

This commit is contained in:
Iwan Kawrakow
2025-02-08 12:08:17 +02:00
parent 37c4fbd7f9
commit 3aaf602da5

View File

@@ -13463,7 +13463,7 @@ struct llm_build_context {
0);
cb(kv_cache_trans, "kv_cache_trans", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
//q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13472,7 +13472,7 @@ struct llm_build_context {
cb(q_pe, "q_pe", il);
// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
//k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13508,8 +13508,9 @@ struct llm_build_context {
struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
cb(kq_nope, "kq_nope", il);
struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
cb(q_pe_perm, "q_pe_perm", il);
// Huh? This is not used anywhere
//struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
//cb(q_pe_perm, "q_pe_perm", il);
struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
cb(kq_pe, "kq_pe", il);
@@ -13517,6 +13518,7 @@ struct llm_build_context {
struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
cb(kq, "kq", il);
// We need this copy because soft_max expects a contiguous tensor
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);