diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 4fac11e3..de302a1f 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1510,8 +1510,12 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
 
     const float scale = 1.0f / sqrtf((float) head_dim);
 
+#ifdef __AVX512F__
+    __m512 v_prime[head_dim/16], out_val[head_dim/16];
+#else
     float v_new_buf[head_dim];
     float v_prime[head_dim], out_val[head_dim];
+#endif
 
     for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
         const int batch_idx = h_idx / n_heads;
@@ -1539,7 +1543,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
 
             const float beta_raw = beta_data[g_head_offset + t];
             float kq_sum = 0.0f;
-#ifdef __AVX2__
+#if defined __AVX512F__
+            auto vqksum = _mm512_setzero_ps();
+            for (int i = 0; i < head_dim; i += 16) {
+                auto vq = _mm512_loadu_ps(q_t + i);
+                auto vk = _mm512_loadu_ps(k_t + i);
+                vqksum = _mm512_fmadd_ps(vk, vq, vqksum);
+            }
+            kq_sum = _mm512_reduce_add_ps(vqksum);
+#elif defined __AVX2__
             auto vqksum = _mm256_setzero_ps();
             for (int i = 0; i < head_dim; i += 8) {
                 auto vq = _mm256_loadu_ps(q_t + i);
@@ -1560,6 +1572,42 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
 
             float * out_t = out_data + out_head_offset + t * out_token_stride;
 
+#ifdef __AVX512F__
+            for (int j = 0; j < head_dim/16; ++j) {
+                v_prime[j] = out_val[j] = _mm512_setzero_ps();
+            }
+            for (int col = 0; col < head_dim; ++col) {
+                auto k_col = _mm512_set1_ps(k_t[col]);
+                auto q_col = _mm512_set1_ps(q_t[col]);
+                for (int j = 0; j < head_dim/16; ++j) {
+                    auto s = _mm512_loadu_ps(state + col * head_dim + 16*j);
+                    v_prime[j] = _mm512_fmadd_ps(s, k_col, v_prime[j]);
+                    out_val[j] = _mm512_fmadd_ps(s, q_col, out_val[j]);
+                }
+            }
+            auto c1 = _mm512_set1_ps(beta_val);
+            auto c2 = _mm512_set1_ps(beta_val*decay);
+            auto c3 = _mm512_set1_ps(decay*scale);
+            auto c4 = _mm512_set1_ps(attn_score);
+            for (int j = 0; j < head_dim/16; ++j) {
+                auto v = _mm512_loadu_ps(v_t + 16*j);
+                v_prime[j] = _mm512_sub_ps(_mm512_mul_ps(v, c1), _mm512_mul_ps(v_prime[j], c2));
+                auto oval = _mm512_fmadd_ps(v_prime[j], c4, _mm512_mul_ps(out_val[j], c3));
+                _mm512_storeu_ps(out_t + 16*j, oval);
+            }
+            auto vmin = _mm512_set1_ps(-1e6f);
+            auto vmax = _mm512_set1_ps( 1e6f);
+            auto vd = _mm512_set1_ps(decay);
+            for (int col = 0; col < head_dim; ++col) {
+                auto vk = _mm512_set1_ps(k_t[col]);
+                for (int j = 0; j < head_dim/16; ++j) {
+                    auto vs = _mm512_loadu_ps(state + col * head_dim + 16*j);
+                    vs = _mm512_fmadd_ps(v_prime[j], vk, _mm512_mul_ps(vs, vd));
+                    vs = _mm512_max_ps(vmin, _mm512_min_ps(vmax, vs));
+                    _mm512_storeu_ps(state + col * head_dim + 16*j, vs);
+                }
+            }
+#else
             std::memset(v_prime, 0, head_dim*sizeof(float));
             std::memset(out_val, 0, head_dim*sizeof(float));
             for (int col = 0; col < head_dim; ++col) {
@@ -1603,6 +1651,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
                     state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                 }
             }
+#endif
 #endif
         }
     }
diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp
index bed2780a..092a2908 100644
--- a/src/llama-delta-net.cpp
+++ b/src/llama-delta-net.cpp
@@ -134,12 +134,12 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_co
 
     const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
     const int64_t state_size = S_v * S_v * H_v * n_seqs;
 
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
+    auto output_tokens = ggml_view_4d(ctx0, fused_result,
             S_v, H_v, n_tokens, n_seqs,
             ggml_row_size(fused_result->type, S_v),
             ggml_row_size(fused_result->type, S_v * H_v),
             ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
-    output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
+    //output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
 
     ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
             output_size * ggml_element_size(fused_result));