Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-03-06 03:50:08 +00:00.
Commit: "CPU optimizations". This commit is contained in:
@@ -22562,6 +22562,11 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
|
||||
out_data, state_out, ith, nth)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t total_heads = n_heads * n_seqs;
|
||||
const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
|
||||
const int64_t h_start = ith * heads_per_thread;
|
||||
@@ -22571,9 +22576,7 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
const float scale = 1.0f / sqrtf((float) head_dim);
|
||||
|
||||
float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
|
||||
if (!v_new_buf) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(v_new_buf);
|
||||
|
||||
for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
|
||||
const int64_t batch_idx = h_idx / n_heads;
|
||||
@@ -22624,17 +22627,17 @@ static void ggml_compute_forward_delta_net_f32(
|
||||
float out_val = 0.0f;
|
||||
|
||||
for (int64_t col = 0; col < head_dim; ++col) {
|
||||
const float k_col = k_t[col] * k_norm_inv;
|
||||
const float q_col = q_t[col] * q_norm_inv * scale;
|
||||
const float k_col = k_t[col];
|
||||
const float q_col = q_t[col];
|
||||
const float s = state[row + col * head_dim];
|
||||
|
||||
v_prime += s * k_col * beta_val * decay;
|
||||
out_val += s * q_col * decay;
|
||||
v_prime += s * k_col;
|
||||
out_val += s * q_col;
|
||||
}
|
||||
|
||||
const float v_new = v_t[row] * beta_val - v_prime;
|
||||
const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
|
||||
v_new_buf[row] = v_new;
|
||||
out_t[row] = out_val + v_new * attn_score;
|
||||
out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
|
||||
}
|
||||
|
||||
for (int64_t col = 0; col < head_dim; ++col) {
|
||||
|
||||
@@ -1383,6 +1383,112 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace {
// Scalar CPU kernel for the fused "delta net" recurrence (gated DeltaNet-style
// linear attention).  For each (sequence, head) pair it walks the tokens in
// order, writing the attention output for every token and updating the
// head_dim x head_dim recurrent state in place.
//
// head_dim is a template parameter so the per-token scratch buffer
// (v_new_buf) can be a fixed-size stack array.
//
// Threading: the n_heads * n_seqs heads are split evenly over nth workers;
// this worker (index ith) processes heads in [h_start, h_end).  Workers touch
// disjoint heads, so no synchronization is needed between them.
//
// Tensor layouts as implied by the offset arithmetic below (innermost first
// — TODO confirm against the ggml caller):
//   q/k/v    : [head_dim, n_tokens, n_heads, n_seqs]
//   g/beta   : [n_tokens,  n_heads, n_seqs]   (one scalar per token per head)
//   state    : [head_dim, head_dim, n_heads, n_seqs]
//   out      : [head_dim, n_heads,  n_tokens, n_seqs] (token stride = head_dim*n_heads)
template <int head_dim>
void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
    // Even split of heads across threads; the last thread may get fewer.
    const int total_heads = n_heads * n_seqs;
    const int heads_per_thread = (total_heads + nth - 1) / nth;
    const int h_start = ith * heads_per_thread;
    const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;

    const float eps = 1e-12f;                          // guards the rsqrt in L2 normalization
    const float scale = 1.0f / sqrtf((float) head_dim); // standard 1/sqrt(d) attention scaling

    // Per-token scratch: the updated value vector, reused for the state update.
    float v_new_buf[head_dim];

    for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
        const int batch_idx = h_idx / n_heads;
        const int head_idx = h_idx % n_heads;

        // Flat offsets of this (batch, head) slice in each tensor.
        const int qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
        const int qkv_token_stride = head_dim;
        const int g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
        const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
        const int out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
        const int out_token_stride = head_dim * n_heads;

        // Seed the output state with the incoming state; the recurrence then
        // updates state_out in place, leaving state_in untouched.
        for (int i = 0; i < head_dim * head_dim; ++i) {
            state_out[state_head_offset + i] = state_in[state_head_offset + i];
        }

        float * state = state_out + state_head_offset;

        for (int t = 0; t < n_tokens; ++t) {
            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;

            const float g_val = g_data[g_head_offset + t];
            const float beta_raw = beta_data[g_head_offset + t];

            // L2-normalization factors for q and k (eps avoids division by zero).
            float q_norm_sq = 0.0f;
            float k_norm_sq = 0.0f;
            for (int i = 0; i < head_dim; ++i) {
                q_norm_sq += q_t[i] * q_t[i];
                k_norm_sq += k_t[i] * k_t[i];
            }
            const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
            const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);

            // beta gate: sigmoid of the raw value, so beta_val is in (0, 1).
            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
            // decay gate: exp(g), clamped at exp(50) to avoid overflow.
            // Presumably g_val <= 0 so decay <= 1 — TODO confirm with the caller.
            const float decay = expf(fminf(g_val, 50.0f));

            // attn_score = <k_hat, q_hat * scale>: the current token's
            // self-attention weight applied to v_new below.
            float attn_score = 0.0f;
            for (int i = 0; i < head_dim; ++i) {
                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
            }

            float * out_t = out_data + out_head_offset + t * out_token_stride;

            // Pass 1 (per output row): read the old state once, accumulating
            // the raw (unnormalized) products state.k and state.q; the scalar
            // factors (beta, decay, norms, scale) are applied once per row
            // after the loop instead of once per column.
            //   v_new   = beta*v  - beta*decay*k_norm_inv * (state . k)
            //   out     = decay*q_norm_inv*scale * (state . q) + v_new * attn_score
            for (int row = 0; row < head_dim; ++row) {
                float v_prime = 0.0f;
                float out_val = 0.0f;

                for (int col = 0; col < head_dim; ++col) {
                    const float k_col = k_t[col];
                    const float q_col = q_t[col];
                    const float s = state[row + col * head_dim];

                    v_prime += s * k_col;
                    out_val += s * q_col;
                }

                const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
                v_new_buf[row] = v_new;
                out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
            }

            // Pass 2: rank-1 state update, S <- decay*S + v_new (outer) k_hat,
            // with each entry clamped to [-1e6, 1e6] to keep the recurrence
            // from blowing up over long sequences.
            for (int col = 0; col < head_dim; ++col) {
                const float k_col = k_t[col] * k_norm_inv;
                for (int row = 0; row < head_dim; ++row) {
                    float s = state[row + col * head_dim];
                    s = decay * s + v_new_buf[row] * k_col;
                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                }
            }
        }
    }
}
}
|
||||
|
||||
// Public entry point for the fused delta-net CPU kernel.  Returns true when
// the work was handled here; false tells the caller to fall back to the
// generic ggml implementation.  Only head sizes 64 and 128 are instantiated
// (head_dim is a compile-time template parameter of the kernel).
bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
    switch (head_dim) {
        case 64:
            iqk_fused_delta_net_impl< 64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data,
                    state_in, out_data, state_out, ith, nth);
            return true;
        case 128:
            iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data,
                    state_in, out_data, state_out, ith, nth);
            return true;
        default:
            // Unsupported head size — let the caller take the generic path.
            return false;
    }
}
|
||||
|
||||
#else // IQK_IMPLEMENT
|
||||
|
||||
#include "ggml-impl.h"
|
||||
@@ -1416,4 +1522,11 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
|
||||
return false;
|
||||
}
|
||||
|
||||
// Stub used when IQK_IMPLEMENT is not defined: the fused delta-net kernel is
// unavailable, so always report "not handled" and let the caller fall back to
// the generic ggml implementation.  Parameters are intentionally unnamed —
// they are never read.
bool iqk_fused_delta_net(int /*head_dim*/, int /*n_heads*/, int /*n_tokens*/, int /*n_seqs*/,
        const float * /*q_data*/, const float * /*k_data*/, const float * /*v_data*/, const float * /*g_data*/, const float * /*beta_data*/,
        const float * /*state_in*/, float * /*out_data*/, float * /*state_out*/, int /*ith*/, int /*nth*/) {
    return false;
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -73,6 +73,10 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
|
||||
float * weights, int32_t * ids, int ith, int nth);
|
||||
|
||||
IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -406,7 +406,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
g = ggml_permute(ctx0, g, 2, 0, 3, 1);
|
||||
beta = ggml_permute(ctx0, beta, 2, 0, 1, 3);
|
||||
if (n_seqs > 1) {
|
||||
if (n_seqs > 1 || n_tokens > 1) {
|
||||
q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs);
|
||||
@@ -680,15 +680,16 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
|
||||
GGML_ASSERT(causal_mask != nullptr);
|
||||
GGML_ASSERT(identity != nullptr);
|
||||
GGML_ASSERT(diag_mask != nullptr);
|
||||
|
||||
attn_out = n_tok == 1 ? lctx.cparams.fused_delta_net ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
|
||||
: build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
|
||||
: build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
// The fused delta-net implementation is only faster than chunked for n_tok <= 8, so use it only in that case
|
||||
attn_out = lctx.cparams.fused_delta_net && n_tok <= 8 ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb) :
|
||||
n_tok == 1 ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
|
||||
: build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
|
||||
Reference in New Issue
Block a user