Much faster fused delta-net on the CPU

It appears to be faster than the chunked implementation!
This commit is contained in:
Kawrakow
2026-02-24 12:42:06 +00:00
parent 2ef38b56df
commit b184e84480

View File

@@ -1393,10 +1393,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
const int h_start = ith * heads_per_thread;
const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
#ifdef __AVX2__
static_assert(head_dim % 8 == 0);
#endif
const float eps = 1e-12f;
const float scale = 1.0f / sqrtf((float) head_dim);
float v_new_buf[head_dim];
float v_prime[head_dim], out_val[head_dim];
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
const int batch_idx = h_idx / n_heads;
@@ -1425,41 +1430,78 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
float q_norm_sq = 0.0f;
float k_norm_sq = 0.0f;
float kq_sum = 0.0f;
#ifdef __AVX2__
auto vqsum = _mm256_setzero_ps();
auto vksum = _mm256_setzero_ps();
auto vqksum = _mm256_setzero_ps();
for (int i = 0; i < head_dim; i += 8) {
auto vq = _mm256_loadu_ps(q_t + i);
auto vk = _mm256_loadu_ps(k_t + i);
vqsum = _mm256_fmadd_ps(vq, vq, vqsum);
vksum = _mm256_fmadd_ps(vk, vk, vksum);
vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
}
q_norm_sq = hsum_float_8(vqsum);
k_norm_sq = hsum_float_8(vksum);
kq_sum = hsum_float_8(vqksum);
#else
for (int i = 0; i < head_dim; ++i) {
q_norm_sq += q_t[i] * q_t[i];
k_norm_sq += k_t[i] * k_t[i];
kq_sum += k_t[i] * q_t[i];
}
#endif
const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
const float decay = expf(fminf(g_val, 50.0f));
float attn_score = 0.0f;
for (int i = 0; i < head_dim; ++i) {
attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
}
float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;
//float attn_score = 0.0f;
//for (int i = 0; i < head_dim; ++i) {
// attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
//}
float * out_t = out_data + out_head_offset + t * out_token_stride;
for (int row = 0; row < head_dim; ++row) {
float v_prime = 0.0f;
float out_val = 0.0f;
for (int col = 0; col < head_dim; ++col) {
const float k_col = k_t[col];
const float q_col = q_t[col];
std::memset(v_prime, 0, head_dim*sizeof(float));
std::memset(out_val, 0, head_dim*sizeof(float));
for (int col = 0; col < head_dim; ++col) {
const float k_col = k_t[col];
const float q_col = q_t[col];
for (int row = 0; row < head_dim; ++row) {
const float s = state[row + col * head_dim];
v_prime += s * k_col;
out_val += s * q_col;
v_prime[row] += s * k_col;
out_val[row] += s * q_col;
}
const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
}
for (int row = 0; row < head_dim; ++row) {
const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
v_new_buf[row] = v_new;
out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
}
#ifdef __AVX2__
auto vd = _mm256_set1_ps(decay);
auto vmin = _mm256_set1_ps(-1e6f);
auto vmax = _mm256_set1_ps( 1e6f);
for (int col = 0; col < head_dim; ++col) {
auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
for (int row = 0; row < head_dim; row += 8) {
auto vs = _mm256_loadu_ps(state + col * head_dim + row);
auto vn = _mm256_loadu_ps(v_new_buf + row);
vs = _mm256_fmadd_ps(vn, vk, _mm256_mul_ps(vs, vd));
auto mask_l = _mm256_cmp_ps(vs, vmin, _CMP_LT_OQ);
auto mask_u = _mm256_cmp_ps(vs, vmax, _CMP_GT_OQ);
vs = _mm256_or_ps(_mm256_and_ps(mask_l, vmin), _mm256_andnot_ps(mask_l, vs));
vs = _mm256_or_ps(_mm256_and_ps(mask_u, vmax), _mm256_andnot_ps(mask_u, vs));
_mm256_storeu_ps(state + col * head_dim + row, vs);
}
}
#else
for (int col = 0; col < head_dim; ++col) {
const float k_col = k_t[col] * k_norm_inv;
for (int row = 0; row < head_dim; ++row) {
@@ -1468,6 +1510,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
}
}
#endif
}
}
}