Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-03-01 01:24:08 +00:00
Much faster fused delta-net on the CPU
It seems it is faster than the chunked implementation!
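For orientation, the per-token recurrence that the fused kernel evaluates can be read off the scalar paths in the diff below. A minimal unvectorized sketch for a single head follows; the function name `delta_net_token_ref` and the flat pointer interface are illustrative only, not part of the commit:

    #include <cmath>

    // Reference (scalar) version of the per-token delta-net update performed by
    // iqk_fused_delta_net_impl for one head. The state is head_dim x head_dim,
    // stored as state[row + col * head_dim], exactly as in the kernel below.
    static void delta_net_token_ref(int head_dim,
                                    const float * q, const float * k, const float * v,
                                    float beta_raw, float g_val,
                                    float * state, float * out) {
        const float eps   = 1e-12f;
        const float scale = 1.0f / sqrtf((float) head_dim);

        float q_norm_sq = 0.0f, k_norm_sq = 0.0f, kq_sum = 0.0f;
        for (int i = 0; i < head_dim; ++i) {
            q_norm_sq += q[i] * q[i];
            k_norm_sq += k[i] * k[i];
            kq_sum    += k[i] * q[i];
        }
        const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
        const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);

        const float beta_val   = 1.0f / (1.0f + expf(-beta_raw));          // sigmoid gate
        const float decay      = expf(fminf(g_val, 50.0f));                // clamped decay factor
        const float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale; // normalized k.q

        for (int row = 0; row < head_dim; ++row) {
            // v_prime = (S k)[row] and out_val = (S q)[row] against the current state S
            float v_prime = 0.0f, out_val = 0.0f;
            for (int col = 0; col < head_dim; ++col) {
                const float s = state[row + col * head_dim];
                v_prime += s * k[col];
                out_val += s * q[col];
            }
            const float v_new = v[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
            out[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;

            // State update: S <- decay * S + v_new * (k * k_norm_inv)^T, clamped to +/-1e6.
            // Per-row updates are safe here because each row only touches its own entries.
            for (int col = 0; col < head_dim; ++col) {
                const float s = state[row + col * head_dim] * decay + v_new * k[col] * k_norm_inv;
                state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
            }
        }
    }

The diff below computes the q/k norms and k.q in a single pass, swaps the row/column loops so the innermost index walks the state contiguously, and vectorizes both that pass and the state update with AVX2.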
@@ -1393,10 +1393,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
     const int h_start = ith * heads_per_thread;
     const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
 
+#ifdef __AVX2__
+    static_assert(head_dim % 8 == 0);
+#endif
+
     const float eps = 1e-12f;
     const float scale = 1.0f / sqrtf((float) head_dim);
 
     float v_new_buf[head_dim];
+    float v_prime[head_dim], out_val[head_dim];
 
     for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
         const int batch_idx = h_idx / n_heads;
@@ -1425,41 +1430,78 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
 
             float q_norm_sq = 0.0f;
             float k_norm_sq = 0.0f;
+            float kq_sum = 0.0f;
+#ifdef __AVX2__
+            auto vqsum = _mm256_setzero_ps();
+            auto vksum = _mm256_setzero_ps();
+            auto vqksum = _mm256_setzero_ps();
+            for (int i = 0; i < head_dim; i += 8) {
+                auto vq = _mm256_loadu_ps(q_t + i);
+                auto vk = _mm256_loadu_ps(k_t + i);
+                vqsum = _mm256_fmadd_ps(vq, vq, vqsum);
+                vksum = _mm256_fmadd_ps(vk, vk, vksum);
+                vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
+            }
+            q_norm_sq = hsum_float_8(vqsum);
+            k_norm_sq = hsum_float_8(vksum);
+            kq_sum = hsum_float_8(vqksum);
+#else
             for (int i = 0; i < head_dim; ++i) {
                 q_norm_sq += q_t[i] * q_t[i];
                 k_norm_sq += k_t[i] * k_t[i];
+                kq_sum += k_t[i] * q_t[i];
             }
+#endif
             const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
             const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
 
             const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
             const float decay = expf(fminf(g_val, 50.0f));
 
-            float attn_score = 0.0f;
-            for (int i = 0; i < head_dim; ++i) {
-                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
-            }
+            float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;
+
+            //float attn_score = 0.0f;
+            //for (int i = 0; i < head_dim; ++i) {
+            //    attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
+            //}
 
             float * out_t = out_data + out_head_offset + t * out_token_stride;
 
-            for (int row = 0; row < head_dim; ++row) {
-                float v_prime = 0.0f;
-                float out_val = 0.0f;
-
-                for (int col = 0; col < head_dim; ++col) {
-                    const float k_col = k_t[col];
-                    const float q_col = q_t[col];
+            std::memset(v_prime, 0, head_dim*sizeof(float));
+            std::memset(out_val, 0, head_dim*sizeof(float));
+            for (int col = 0; col < head_dim; ++col) {
+                const float k_col = k_t[col];
+                const float q_col = q_t[col];
+                for (int row = 0; row < head_dim; ++row) {
                     const float s = state[row + col * head_dim];
-
-                    v_prime += s * k_col;
-                    out_val += s * q_col;
+                    v_prime[row] += s * k_col;
+                    out_val[row] += s * q_col;
                 }
-
-                const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
+            }
+            for (int row = 0; row < head_dim; ++row) {
+                const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
                 v_new_buf[row] = v_new;
-                out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
+                out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
             }
 
+#ifdef __AVX2__
+            auto vd = _mm256_set1_ps(decay);
+            auto vmin = _mm256_set1_ps(-1e6f);
+            auto vmax = _mm256_set1_ps( 1e6f);
+            for (int col = 0; col < head_dim; ++col) {
+                auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
+                for (int row = 0; row < head_dim; row += 8) {
+                    auto vs = _mm256_loadu_ps(state + col * head_dim + row);
+                    auto vn = _mm256_loadu_ps(v_new_buf + row);
+                    vs = _mm256_fmadd_ps(vn, vk, _mm256_mul_ps(vs, vd));
+                    auto mask_l = _mm256_cmp_ps(vs, vmin, _CMP_LT_OQ);
+                    auto mask_u = _mm256_cmp_ps(vs, vmax, _CMP_GT_OQ);
+                    vs = _mm256_or_ps(_mm256_and_ps(mask_l, vmin), _mm256_andnot_ps(mask_l, vs));
+                    vs = _mm256_or_ps(_mm256_and_ps(mask_u, vmax), _mm256_andnot_ps(mask_u, vs));
+                    _mm256_storeu_ps(state + col * head_dim + row, vs);
+                }
+            }
+#else
             for (int col = 0; col < head_dim; ++col) {
                 const float k_col = k_t[col] * k_norm_inv;
                 for (int row = 0; row < head_dim; ++row) {
@@ -1468,6 +1510,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
                     state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                 }
             }
+#endif
         }
     }
}
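The hsum_float_8 helper used by the AVX2 path is defined elsewhere in the code base and is not part of this commit. For readers following along, a typical ggml-style horizontal sum is assumed to look like this sketch:

    #include <immintrin.h>

    // Horizontal sum of the eight floats in an __m256, in the usual ggml style.
    // Assumed to match the hsum_float_8 referenced in the AVX2 path above.
    static inline float hsum_float_8(const __m256 x) {
        __m128 res = _mm256_extractf128_ps(x, 1);           // upper four lanes
        res = _mm_add_ps(res, _mm256_castps256_ps128(x));   // + lower four lanes
        res = _mm_add_ps(res, _mm_movehl_ps(res, res));     // fold 4 -> 2
        res = _mm_add_ss(res, _mm_movehdup_ps(res));        // fold 2 -> 1
        return _mm_cvtss_f32(res);
    }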
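One detail worth noting: the mask_l/mask_u compare-and-blend sequence in the AVX2 state update is simply a clamp of the updated state to [-1e6, 1e6]. For non-NaN lanes it is equivalent to a plain min/max pair; a self-contained sketch (the helper name clamp_ps is illustrative):

    #include <immintrin.h>

    // Clamp every lane of x to [lo, hi]; equivalent (for non-NaN lanes) to the
    // compare-and-blend masking used in the AVX2 state update above.
    static inline __m256 clamp_ps(__m256 x, float lo, float hi) {
        return _mm256_min_ps(_mm256_max_ps(x, _mm256_set1_ps(lo)), _mm256_set1_ps(hi));
    }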