Hadamard transforms for K-cache - CPU only (#1033)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-12-04 06:51:11 +01:00
committed by GitHub
parent 0581f90c0f
commit 18fdd80eaf
13 changed files with 155 additions and 20 deletions

View File

@@ -52,6 +52,7 @@ llm_build_context::llm_build_context(
fused_up_gate (cparams.fused_up_gate),
fused_mmad (cparams.fused_mmad),
rope_cache (cparams.rope_cache),
k_cache_hadamard (cparams.k_cache_hadamard),
min_experts (cparams.min_experts),
thresh_experts (cparams.thresh_experts),
pooling_type (cparams.pooling_type),
@@ -1466,6 +1467,13 @@ ggml_tensor * llm_build_context::llm_build_kv(
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
if (cparams.k_cache_hadamard) {
q_cur = ggml_hadamard(ctx, q_cur, hparams.n_embd_head_k);
k_cur = ggml_hadamard(ctx, k_cur, hparams.n_embd_head_k);
cb(q_cur, "Qcur_hadamard", il);
cb(k_cur, "Kcur_hadamard", il);
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(graph, q_cur);
@@ -9375,6 +9383,12 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
cb(Qcur, "Qcur_temp_scaled", il_cb);
}
if (cparams.k_cache_hadamard) {
Qcur = ggml_hadamard(ctx0, Qcur, hparams.n_embd_head_k);
Kcur = ggml_hadamard(ctx0, Kcur, hparams.n_embd_head_k);
cb(Qcur, "Qcur_hadamard", il_cb);
cb(Kcur, "Kcur_hadamard", il_cb);
}
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);

View File

@@ -82,6 +82,7 @@ struct llm_build_context {
const bool fused_up_gate;
const bool fused_mmad;
const bool rope_cache;
const bool k_cache_hadamard;
const int min_experts;
const float thresh_experts;

View File

@@ -39,6 +39,7 @@ struct llama_cparams {
bool fused_mmad;
bool rope_cache;
bool graph_reuse;
bool k_cache_hadamard;
int min_experts;
float thresh_experts;

View File

@@ -4048,6 +4048,7 @@ struct llama_context_params llama_context_default_params() {
/*.min_experts =*/ -1,
/*.thtesh_experts =*/ 0.0f,
/*.only_active_experts =*/ false,
/*.k_cache_hadamard =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
/*.offload_policy =*/ nullptr,
@@ -4297,6 +4298,11 @@ struct llama_context * llama_new_context_with_model(
return nullptr;
}
if (params.k_cache_hadamard && !ggml_is_quantized(params.type_k)) {
LLAMA_LOG_WARN("%s: there is no point in Hadamard transforms with not quantized K-cache. Turning Hadamard off\n", __func__);
params.k_cache_hadamard = false;
}
llama_context * ctx = new llama_context(*model);
// add devices to ctx->cparams from model
@@ -4330,6 +4336,7 @@ struct llama_context * llama_new_context_with_model(
cparams.fused_mmad = params.fused_mmad;
cparams.rope_cache = params.rope_cache;
cparams.graph_reuse = params.graph_reuse;
cparams.k_cache_hadamard = params.k_cache_hadamard;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
cparams.cuda_params = params.cuda_params;
@@ -4417,6 +4424,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
LLAMA_LOG_INFO("%s: rope_cache = %d\n", __func__, cparams.rope_cache);
LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
LLAMA_LOG_INFO("%s: k_cache_hadam = %d\n", __func__, cparams.k_cache_hadamard);
LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);