From 2ef38b56df61618b99f7f4cddc7da4881b7f0a3f Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Tue, 24 Feb 2026 10:41:19 +0000
Subject: [PATCH] CPU optimizations

---
 ggml/src/ggml.c              |  21 ++++---
 ggml/src/iqk/iqk_mul_mat.cpp | 113 +++++++++++++++++++++++++++++++++++
 ggml/src/iqk/iqk_mul_mat.h   |   4 ++
 src/llama-delta-net.cpp      |  13 ++--
 4 files changed, 136 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 27211f32..059b589c 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -22562,6 +22562,11 @@ static void ggml_compute_forward_delta_net_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
+                out_data, state_out, ith, nth)) {
+        return;
+    }
+
     const int64_t total_heads = n_heads * n_seqs;
     const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
     const int64_t h_start = ith * heads_per_thread;
@@ -22571,9 +22576,7 @@ static void ggml_compute_forward_delta_net_f32(
     const float scale = 1.0f / sqrtf((float) head_dim);
 
     float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
-    if (!v_new_buf) {
-        return;
-    }
+    GGML_ASSERT(v_new_buf);
 
     for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
         const int64_t batch_idx = h_idx / n_heads;
@@ -22624,17 +22627,17 @@ static void ggml_compute_forward_delta_net_f32(
                 float out_val = 0.0f;
 
                 for (int64_t col = 0; col < head_dim; ++col) {
-                    const float k_col = k_t[col] * k_norm_inv;
-                    const float q_col = q_t[col] * q_norm_inv * scale;
+                    const float k_col = k_t[col];
+                    const float q_col = q_t[col];
                     const float s = state[row + col * head_dim];
 
-                    v_prime += s * k_col * beta_val * decay;
-                    out_val += s * q_col * decay;
+                    v_prime += s * k_col;
+                    out_val += s * q_col;
                 }
 
-                const float v_new = v_t[row] * beta_val - v_prime;
+                const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
                 v_new_buf[row] = v_new;
-                out_t[row] = out_val + v_new * attn_score;
+                out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
             }
 
             for (int64_t col = 0; col < head_dim; ++col) {
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index c580edbf..0b0537f2 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1383,6 +1383,112 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
 }
 #endif
 
+namespace {
+template <int head_dim>
+void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
+        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
+        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
+    const int total_heads = n_heads * n_seqs;
+    const int heads_per_thread = (total_heads + nth - 1) / nth;
+    const int h_start = ith * heads_per_thread;
+    const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
+
+    const float eps = 1e-12f;
+    const float scale = 1.0f / sqrtf((float) head_dim);
+
+    float v_new_buf[head_dim];
+
+    for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
+        const int batch_idx = h_idx / n_heads;
+        const int head_idx = h_idx % n_heads;
+
+        const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
+        const int qkv_token_stride = head_dim;
+        const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
+        const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
+        const int out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
+        const int out_token_stride = head_dim * n_heads;
+
+        for (int i = 0; i < head_dim * head_dim; ++i) {
+            state_out[state_head_offset + i] = state_in[state_head_offset + i];
+        }
+
+        float * state = state_out + state_head_offset;
+
+        for (int t = 0; t < n_tokens; ++t) {
+            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
+            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
+            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
+
+            const float g_val = g_data[g_head_offset + t];
+            const float beta_raw = beta_data[g_head_offset + t];
+
+            float q_norm_sq = 0.0f;
+            float k_norm_sq = 0.0f;
+            for (int i = 0; i < head_dim; ++i) {
+                q_norm_sq += q_t[i] * q_t[i];
+                k_norm_sq += k_t[i] * k_t[i];
+            }
+            const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
+            const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
+
+            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
+            const float decay = expf(fminf(g_val, 50.0f));
+
+            float attn_score = 0.0f;
+            for (int i = 0; i < head_dim; ++i) {
+                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
+            }
+
+            float * out_t = out_data + out_head_offset + t * out_token_stride;
+
+            for (int row = 0; row < head_dim; ++row) {
+                float v_prime = 0.0f;
+                float out_val = 0.0f;
+
+                for (int col = 0; col < head_dim; ++col) {
+                    const float k_col = k_t[col];
+                    const float q_col = q_t[col];
+                    const float s = state[row + col * head_dim];
+
+                    v_prime += s * k_col;
+                    out_val += s * q_col;
+                }
+
+                const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
+                v_new_buf[row] = v_new;
+                out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
+            }
+
+            for (int col = 0; col < head_dim; ++col) {
+                const float k_col = k_t[col] * k_norm_inv;
+                for (int row = 0; row < head_dim; ++row) {
+                    float s = state[row + col * head_dim];
+                    s = decay * s + v_new_buf[row] * k_col;
+                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
+                }
+            }
+        }
+    }
+}
+}
+
+bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
+        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
+        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
+    if (head_dim != 64 && head_dim != 128) {
+        return false;
+    }
+    if (head_dim == 64) {
+        iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
+                out_data, state_out, ith, nth);
+    } else {
+        iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
+                out_data, state_out, ith, nth);
+    }
+    return true;
+}
+
 #else // IQK_IMPLEMENT
 
 #include "ggml-impl.h"
@@ -1416,4 +1522,11 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
     return false;
 }
 
+bool iqk_fused_delta_net(int, int, int, int,
+        const float *, const float *, const float *, const float *, const float *,
+        const float *, float *, float *, int, int) {
+    return false;
+}
+
+
 #endif
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 904b55ae..440bc815 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -73,6 +73,10 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
 
 IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits, float * weights, int32_t * ids, int ith, int nth);
 
+IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
+        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
+        const float * state_in, float * out_data, float * state_out, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp
index 891a949a..41c18752 100644
--- a/src/llama-delta-net.cpp
+++ b/src/llama-delta-net.cpp
@@ -406,7 +406,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
     g = ggml_permute(ctx0, g, 2, 0, 3, 1);
     beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
-    if (n_seqs > 1) {
+    if (n_seqs > 1 || n_tokens > 1) {
         q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
         k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
         v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
@@ -680,15 +680,16 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out;
-
     GGML_ASSERT(causal_mask != nullptr);
     GGML_ASSERT(identity != nullptr);
     GGML_ASSERT(diag_mask != nullptr);
 
-    attn_out = n_tok == 1 ? lctx.cparams.fused_delta_net ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
-                          : build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
-                          : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out;
+    // The fused delta-net implementation is only faster than chunked for n_tok <= 8, so use it only in that case
+    attn_out = lctx.cparams.fused_delta_net && n_tok <= 8 ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) :
+               n_tok == 1 ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
+                          : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
+
     ggml_tensor * output = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);