diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0b0537f2..c6bd17b0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1393,10 +1393,15 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, const int h_start = ith * heads_per_thread; const int h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads; +#ifdef __AVX2__ + static_assert(head_dim % 8 == 0); +#endif + const float eps = 1e-12f; const float scale = 1.0f / sqrtf((float) head_dim); float v_new_buf[head_dim]; + float v_prime[head_dim], out_val[head_dim]; for (int h_idx = h_start; h_idx < h_end; ++h_idx) { const int batch_idx = h_idx / n_heads; @@ -1425,41 +1430,78 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, float q_norm_sq = 0.0f; float k_norm_sq = 0.0f; + float kq_sum = 0.0f; +#ifdef __AVX2__ + auto vqsum = _mm256_setzero_ps(); + auto vksum = _mm256_setzero_ps(); + auto vqksum = _mm256_setzero_ps(); + for (int i = 0; i < head_dim; i += 8) { + auto vq = _mm256_loadu_ps(q_t + i); + auto vk = _mm256_loadu_ps(k_t + i); + vqsum = _mm256_fmadd_ps(vq, vq, vqsum); + vksum = _mm256_fmadd_ps(vk, vk, vksum); + vqksum = _mm256_fmadd_ps(vk, vq, vqksum); + } + q_norm_sq = hsum_float_8(vqsum); + k_norm_sq = hsum_float_8(vksum); + kq_sum = hsum_float_8(vqksum); +#else for (int i = 0; i < head_dim; ++i) { q_norm_sq += q_t[i] * q_t[i]; k_norm_sq += k_t[i] * k_t[i]; + kq_sum += k_t[i] * q_t[i]; } +#endif const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps); const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps); const float beta_val = 1.0f / (1.0f + expf(-beta_raw)); const float decay = expf(fminf(g_val, 50.0f)); - float attn_score = 0.0f; - for (int i = 0; i < head_dim; ++i) { - attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale); - } + float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale; + + //float attn_score = 0.0f; + //for (int i = 0; i < head_dim; ++i) { + // attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale); + //} float * out_t = out_data + out_head_offset + t * out_token_stride; - for (int row = 0; row < head_dim; ++row) { - float v_prime = 0.0f; - float out_val = 0.0f; - - for (int col = 0; col < head_dim; ++col) { - const float k_col = k_t[col]; - const float q_col = q_t[col]; + std::memset(v_prime, 0, head_dim*sizeof(float)); + std::memset(out_val, 0, head_dim*sizeof(float)); + for (int col = 0; col < head_dim; ++col) { + const float k_col = k_t[col]; + const float q_col = q_t[col]; + for (int row = 0; row < head_dim; ++row) { const float s = state[row + col * head_dim]; - - v_prime += s * k_col; - out_val += s * q_col; + v_prime[row] += s * k_col; + out_val[row] += s * q_col; } - - const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv; + } + for (int row = 0; row < head_dim; ++row) { + const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv; v_new_buf[row] = v_new; - out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score; + out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score; } +#ifdef __AVX2__ + auto vd = _mm256_set1_ps(decay); + auto vmin = _mm256_set1_ps(-1e6f); + auto vmax = _mm256_set1_ps( 1e6f); + for (int col = 0; col < head_dim; ++col) { + auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv); + for (int row = 0; row < head_dim; row += 8) { + auto vs = _mm256_loadu_ps(state + col * head_dim + row); + auto vn = _mm256_loadu_ps(v_new_buf + row); + vs = _mm256_fmadd_ps(vn, vk, _mm256_mul_ps(vs, vd)); + auto mask_l = _mm256_cmp_ps(vs, vmin, _CMP_LT_OQ); + auto mask_u = _mm256_cmp_ps(vs, vmax, _CMP_GT_OQ); + vs = _mm256_or_ps(_mm256_and_ps(mask_l, vmin), _mm256_andnot_ps(mask_l, vs)); + vs = _mm256_or_ps(_mm256_and_ps(mask_u, vmax), _mm256_andnot_ps(mask_u, vs)); + _mm256_storeu_ps(state + col * head_dim + row, vs); + } + } +#else for (int col = 0; col < head_dim; ++col) { const float k_col = k_t[col] * k_norm_inv; for (int row = 0; row < head_dim; ++row) { @@ -1468,6 +1510,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs, state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f); } } +#endif } } }