Fused delta-net (AVX512) (#1362)

This commit is contained in:
Kawrakow
2026-03-05 07:55:05 +01:00
committed by GitHub
parent 2add439e43
commit 8fb002207a
2 changed files with 52 additions and 3 deletions

View File

@@ -1510,8 +1510,12 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
const float scale = 1.0f / sqrtf((float) head_dim);
#ifdef __AVX512F__
__m512 v_prime[head_dim/16], out_val[head_dim/16];
#else
float v_new_buf[head_dim];
float v_prime[head_dim], out_val[head_dim];
#endif
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
const int batch_idx = h_idx / n_heads;
@@ -1539,7 +1543,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
const float beta_raw = beta_data[g_head_offset + t];
float kq_sum = 0.0f;
#ifdef __AVX2__
#if defined __AVX512F__
auto vqksum = _mm512_setzero_ps();
for (int i = 0; i < head_dim; i += 16) {
auto vq = _mm512_loadu_ps(q_t + i);
auto vk = _mm512_loadu_ps(k_t + i);
vqksum = _mm512_fmadd_ps(vk, vq, vqksum);
}
kq_sum = _mm512_reduce_add_ps(vqksum);
#elif defined __AVX2__
auto vqksum = _mm256_setzero_ps();
for (int i = 0; i < head_dim; i += 8) {
auto vq = _mm256_loadu_ps(q_t + i);
@@ -1560,6 +1572,42 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
float * out_t = out_data + out_head_offset + t * out_token_stride;
#ifdef __AVX512F__
for (int j = 0; j < head_dim/16; ++j) {
v_prime[j] = out_val[j] = _mm512_setzero_ps();
}
for (int col = 0; col < head_dim; ++col) {
auto k_col = _mm512_set1_ps(k_t[col]);
auto q_col = _mm512_set1_ps(q_t[col]);
for (int j = 0; j < head_dim/16; ++j) {
auto s = _mm512_loadu_ps(state + col * head_dim + 16*j);
v_prime[j] = _mm512_fmadd_ps(s, k_col, v_prime[j]);
out_val[j] = _mm512_fmadd_ps(s, q_col, out_val[j]);
}
}
auto c1 = _mm512_set1_ps(beta_val);
auto c2 = _mm512_set1_ps(beta_val*decay);
auto c3 = _mm512_set1_ps(decay*scale);
auto c4 = _mm512_set1_ps(attn_score);
for (int j = 0; j < head_dim/16; ++j) {
auto v = _mm512_loadu_ps(v_t + 16*j);
v_prime[j] = _mm512_sub_ps(_mm512_mul_ps(v, c1), _mm512_mul_ps(v_prime[j], c2));
auto oval = _mm512_fmadd_ps(v_prime[j], c4, _mm512_mul_ps(out_val[j], c3));
_mm512_storeu_ps(out_t + 16*j, oval);
}
auto vmin = _mm512_set1_ps(-1e6f);
auto vmax = _mm512_set1_ps( 1e6f);
auto vd = _mm512_set1_ps(decay);
for (int col = 0; col < head_dim; ++col) {
auto vk = _mm512_set1_ps(k_t[col]);
for (int j = 0; j < head_dim/16; ++j) {
auto vs = _mm512_loadu_ps(state + col * head_dim + 16*j);
vs = _mm512_fmadd_ps(v_prime[j], vk, _mm512_mul_ps(vs, vd));
vs = _mm512_max_ps(vmin, _mm512_min_ps(vmax, vs));
_mm512_storeu_ps(state + col * head_dim + 16*j, vs);
}
}
#else
std::memset(v_prime, 0, head_dim*sizeof(float));
std::memset(out_val, 0, head_dim*sizeof(float));
for (int col = 0; col < head_dim; ++col) {
@@ -1603,6 +1651,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
}
}
#endif
#endif
}
}

View File

@@ -134,12 +134,12 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
auto output_tokens = ggml_view_4d(ctx0, fused_result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(fused_result->type, S_v),
ggml_row_size(fused_result->type, S_v * H_v),
ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
//output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
output_size * ggml_element_size(fused_result));