mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-12 06:50:08 +00:00
ARM_NEON fused delta-net implementation (#1361)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -17935,7 +17935,10 @@ static void ggml_compute_forward_scale_f32(
|
||||
const float * src_data = (const float *)src0->data + block_size*ib;
|
||||
float * dst_data = ( float *)dst->data + block_size*ib;
|
||||
int n = MIN(block_size, nelements - block_size*ib);
|
||||
if (b == 0.0f) {
|
||||
if (s == 0.0f && b == 0.0f) {
|
||||
memset(dst_data, 0, n*sizeof(float));
|
||||
}
|
||||
else if (b == 0.0f) {
|
||||
if (dst->data != src0->data) {
|
||||
// src0 is same shape as dst => same indices
|
||||
memcpy(dst_data, src_data, n * sizeof(float));
|
||||
|
||||
@@ -1384,10 +1384,121 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
#ifdef __ARM_NEON
|
||||
template <int head_dim>
void iqk_fused_delta_net_neon_impl(int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
    // Fused delta-net recurrence (ARM NEON). For every (sequence, head) pair and
    // each token t, with d = head_dim, scale = 1/sqrt(d), and state S (d x d):
    //   beta  = sigmoid(beta_raw)
    //   decay = exp(min(g, 50))
    //   v_new = beta * v_t - beta * decay * (S k_t)
    //   out_t = scale * decay * (S q_t) + (scale * q_t.k_t) * v_new
    //   S     = clamp(decay * S + v_new k_t^T, -1e6, 1e6)
    //
    // Heads are split across threads (ith of nth); each thread owns whole heads,
    // so no synchronization is needed on state_out / out_data.
    const int total_heads = n_heads * n_seqs;
    const int heads_per_thread = (total_heads + nth - 1) / nth;
    const int h_start = ith * heads_per_thread;
    const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;

    // FIX: the state-update kernel below processes tiles of 4 columns x 16 rows
    // with no remainder loop (`row += 16` + vld1q_f32_x4), so head_dim must be a
    // multiple of 16. The original `head_dim % 4 == 0` assertion would let e.g.
    // head_dim == 8 compile and then read/write past the per-head buffers.
    static_assert(head_dim % 16 == 0, "NEON delta-net kernel tiles the state in 4x16 blocks");

    const float scale = 1.0f / sqrtf((float) head_dim);  // 1/sqrt(d) attention scaling

    float v_new_buf[head_dim];                   // v_new for the current token
    float v_prime[head_dim], out_val[head_dim];  // S*k and S*q accumulators

    float32x4x4_t vs4[4];                        // one 16-row x 4-column tile of S

    for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
        const int batch_idx = h_idx / n_heads;
        const int head_idx  = h_idx % n_heads;

        // Layouts implied by the index arithmetic:
        //   q/k/v : [seq][head][token][dim]     g/beta : [seq][head][token]
        //   state : [seq][head][dim][dim]       out    : [seq][token][head][dim]
        const int qkv_head_offset   = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
        const int qkv_token_stride  = head_dim;
        const int g_head_offset     = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
        const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
        const int out_head_offset   = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
        const int out_token_stride  = head_dim * n_heads;

        // The recurrence runs in place on state_out; seed it from state_in.
        for (int i = 0; i < head_dim * head_dim; ++i) {
            state_out[state_head_offset + i] = state_in[state_head_offset + i];
        }

        float * state = state_out + state_head_offset;

        for (int t = 0; t < n_tokens; ++t) {
            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;

            const float g_val    = g_data[g_head_offset + t];
            const float beta_raw = beta_data[g_head_offset + t];

            // q.k dot product, 4 lanes at a time (head_dim % 4 == 0 is implied
            // by the % 16 static_assert above).
            float kq_sum = 0.0f;
            auto vqksum = vdupq_n_f32(0.0f);
            for (int i = 0; i < head_dim; i += 4) {
                auto vq = vld1q_f32(q_t + i);
                auto vk = vld1q_f32(k_t + i);
                vqksum = vfmaq_f32(vqksum, vq, vk);
            }
            kq_sum = vaddvq_f32(vqksum);

            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));  // sigmoid(beta_raw)
            const float decay    = expf(fminf(g_val, 50.0f));        // clamp arg so expf cannot overflow

            float attn_score = kq_sum * scale;

            float * out_t = out_data + out_head_offset + t * out_token_stride;

            // v_prime = S*k and out_val = S*q in one pass; S is stored
            // column-major as state[row + col*head_dim].
            std::memset(v_prime, 0, head_dim*sizeof(float));
            std::memset(out_val, 0, head_dim*sizeof(float));
            for (int col = 0; col < head_dim; ++col) {
                const float k_col = k_t[col];
                const float q_col = q_t[col];
                for (int row = 0; row < head_dim; ++row) {
                    const float s = state[row + col * head_dim];
                    v_prime[row] += s * k_col;
                    out_val[row] += s * q_col;
                }
            }
            // Delta rule: v_new = beta*(v - decay*S*k); the output mixes the
            // decayed state read-out with v_new weighted by the q.k score.
            for (int row = 0; row < head_dim; ++row) {
                const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay;
                v_new_buf[row] = v_new;
                out_t[row] = out_val[row] * decay * scale + v_new * attn_score;
            }

            // State update S = clamp(decay*S + v_new*k^T, +/-1e6), tiled as
            // 4 columns x 16 rows per iteration (hence head_dim % 16 above).
            auto vd   = vdupq_n_f32(decay);
            auto vmin = vdupq_n_f32(-1e6f);
            auto vmax = vdupq_n_f32( 1e6f);
            for (int col = 0; col < head_dim; col += 4) {
                auto vk = vld1q_f32(k_t + col);  // 4 key lanes for this column tile
                for (int row = 0; row < head_dim; row += 16) {
                    // Load + decay-scale the 16x4 tile of S.
                    for (int k = 0; k < 4; ++k) {
                        vs4[k] = vld1q_f32_x4(state + (col + k)*head_dim + row);
                        for (int j = 0; j < 4; ++j) vs4[k].val[j] = vmulq_f32(vs4[k].val[j], vd);
                    }
                    // Rank-1 update: column c gets v_new * k[c] via lane FMA.
                    auto vn = vld1q_f32_x4(v_new_buf + row);
                    for (int j = 0; j < 4; ++j) {
                        vs4[0].val[j] = vfmaq_laneq_f32(vs4[0].val[j], vn.val[j], vk, 0);
                        vs4[1].val[j] = vfmaq_laneq_f32(vs4[1].val[j], vn.val[j], vk, 1);
                        vs4[2].val[j] = vfmaq_laneq_f32(vs4[2].val[j], vn.val[j], vk, 2);
                        vs4[3].val[j] = vfmaq_laneq_f32(vs4[3].val[j], vn.val[j], vk, 3);
                    }
                    // Clamp to +/-1e6 to keep the recurrent state bounded, then store.
                    for (int k = 0; k < 4; ++k) {
                        for (int j = 0; j < 4; ++j) {
                            vs4[k].val[j] = vmaxq_f32(vminq_f32(vs4[k].val[j], vmax), vmin);
                        }
                        vst1q_f32_x4(state + (col + k)*head_dim + row, vs4[k]);
                    }
                }
            }
        }
    }
}
|
||||
#endif
|
||||
template <int head_dim>
|
||||
void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
|
||||
const float * state_in, float * out_data, float * state_out, int ith, int nth) {
|
||||
#ifdef __ARM_NEON
|
||||
iqk_fused_delta_net_neon_impl<head_dim>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in, out_data, state_out, ith, nth);
|
||||
return;
|
||||
#endif
|
||||
const int total_heads = n_heads * n_seqs;
|
||||
const int heads_per_thread = (total_heads + nth - 1) / nth;
|
||||
const int h_start = ith * heads_per_thread;
|
||||
|
||||
Reference in New Issue
Block a user