mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 12:00:29 +00:00
Fused delta-net (AVX512) (#1362)
This commit is contained in:
@@ -1510,8 +1510,12 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
|
||||
const float scale = 1.0f / sqrtf((float) head_dim);
|
||||
|
||||
#ifdef __AVX512F__
|
||||
__m512 v_prime[head_dim/16], out_val[head_dim/16];
|
||||
#else
|
||||
float v_new_buf[head_dim];
|
||||
float v_prime[head_dim], out_val[head_dim];
|
||||
#endif
|
||||
|
||||
for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
|
||||
const int batch_idx = h_idx / n_heads;
|
||||
@@ -1539,7 +1543,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
const float beta_raw = beta_data[g_head_offset + t];
|
||||
|
||||
float kq_sum = 0.0f;
|
||||
#ifdef __AVX2__
|
||||
#if defined __AVX512F__
|
||||
auto vqksum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < head_dim; i += 16) {
|
||||
auto vq = _mm512_loadu_ps(q_t + i);
|
||||
auto vk = _mm512_loadu_ps(k_t + i);
|
||||
vqksum = _mm512_fmadd_ps(vk, vq, vqksum);
|
||||
}
|
||||
kq_sum = _mm512_reduce_add_ps(vqksum);
|
||||
#elif defined __AVX2__
|
||||
auto vqksum = _mm256_setzero_ps();
|
||||
for (int i = 0; i < head_dim; i += 8) {
|
||||
auto vq = _mm256_loadu_ps(q_t + i);
|
||||
@@ -1560,6 +1572,42 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
|
||||
float * out_t = out_data + out_head_offset + t * out_token_stride;
|
||||
|
||||
#ifdef __AVX512F__
|
||||
for (int j = 0; j < head_dim/16; ++j) {
|
||||
v_prime[j] = out_val[j] = _mm512_setzero_ps();
|
||||
}
|
||||
for (int col = 0; col < head_dim; ++col) {
|
||||
auto k_col = _mm512_set1_ps(k_t[col]);
|
||||
auto q_col = _mm512_set1_ps(q_t[col]);
|
||||
for (int j = 0; j < head_dim/16; ++j) {
|
||||
auto s = _mm512_loadu_ps(state + col * head_dim + 16*j);
|
||||
v_prime[j] = _mm512_fmadd_ps(s, k_col, v_prime[j]);
|
||||
out_val[j] = _mm512_fmadd_ps(s, q_col, out_val[j]);
|
||||
}
|
||||
}
|
||||
auto c1 = _mm512_set1_ps(beta_val);
|
||||
auto c2 = _mm512_set1_ps(beta_val*decay);
|
||||
auto c3 = _mm512_set1_ps(decay*scale);
|
||||
auto c4 = _mm512_set1_ps(attn_score);
|
||||
for (int j = 0; j < head_dim/16; ++j) {
|
||||
auto v = _mm512_loadu_ps(v_t + 16*j);
|
||||
v_prime[j] = _mm512_sub_ps(_mm512_mul_ps(v, c1), _mm512_mul_ps(v_prime[j], c2));
|
||||
auto oval = _mm512_fmadd_ps(v_prime[j], c4, _mm512_mul_ps(out_val[j], c3));
|
||||
_mm512_storeu_ps(out_t + 16*j, oval);
|
||||
}
|
||||
auto vmin = _mm512_set1_ps(-1e6f);
|
||||
auto vmax = _mm512_set1_ps( 1e6f);
|
||||
auto vd = _mm512_set1_ps(decay);
|
||||
for (int col = 0; col < head_dim; ++col) {
|
||||
auto vk = _mm512_set1_ps(k_t[col]);
|
||||
for (int j = 0; j < head_dim/16; ++j) {
|
||||
auto vs = _mm512_loadu_ps(state + col * head_dim + 16*j);
|
||||
vs = _mm512_fmadd_ps(v_prime[j], vk, _mm512_mul_ps(vs, vd));
|
||||
vs = _mm512_max_ps(vmin, _mm512_min_ps(vmax, vs));
|
||||
_mm512_storeu_ps(state + col * head_dim + 16*j, vs);
|
||||
}
|
||||
}
|
||||
#else
|
||||
std::memset(v_prime, 0, head_dim*sizeof(float));
|
||||
std::memset(out_val, 0, head_dim*sizeof(float));
|
||||
for (int col = 0; col < head_dim; ++col) {
|
||||
@@ -1603,6 +1651,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
|
||||
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,12 +134,12 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
|
||||
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
|
||||
const int64_t state_size = S_v * S_v * H_v * n_seqs;
|
||||
|
||||
ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
|
||||
auto output_tokens = ggml_view_4d(ctx0, fused_result,
|
||||
S_v, H_v, n_tokens, n_seqs,
|
||||
ggml_row_size(fused_result->type, S_v),
|
||||
ggml_row_size(fused_result->type, S_v * H_v),
|
||||
ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
|
||||
output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
|
||||
//output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
|
||||
output_size * ggml_element_size(fused_result));
|
||||
|
||||
Reference in New Issue
Block a user