diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ad563f36..72684f6f 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6789,6 +6789,7 @@ void iqk_flash_helper_3(int ne00, __m512 vk[16]; __m512 vms[q_step]; __m512 vals[k_step/16]; + float qkv_cache[128*q_step]; bool need_scaling[q_step]; auto vscale = _mm512_set1_ps(scale); for (int i1 = 0; i1 < nq1/q_step; ++i1) { @@ -6814,7 +6815,8 @@ void iqk_flash_helper_3(int ne00, } } for (int j = 0; j < q_step; ++j) { - auto R = qkv + (q_step*i1 + j)*stride_qkv; + auto R = qkv_cache + 128*j; + //auto R = qkv + (q_step*i1 + j)*stride_qkv; for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l)); auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1])); //auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3]))); @@ -6854,7 +6856,8 @@ void iqk_flash_helper_3(int ne00, } for (int i = 0; i < 8; i += 2) { for (int j = 0; j < q_step; ++j) { - auto R = qkv + (q_step*i1 + j)*stride_qkv; + //auto R = qkv + (q_step*i1 + j)*stride_qkv; + auto R = qkv_cache + 128*j; vk[2*j+0] = _mm512_loadu_ps(R + 16*i); vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); if (need_scaling[j]) { @@ -6873,83 +6876,21 @@ void iqk_flash_helper_3(int ne00, } } for (int j = 0; j < q_step; ++j) { - auto R = qkv + (q_step*i1 + j)*stride_qkv; + auto R = qkv_cache + 128*j; + //auto R = qkv + (q_step*i1 + j)*stride_qkv; _mm512_storeu_ps(R + 16*i, vk[2*j+0]); _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); } } - //for (int i = 0; i < 8; ++i) { - // for (int j = 0; j < q_step; ++j) { - // auto R = qkv + (q_step*i1 + j)*stride_qkv; - // vk[j] = _mm512_loadu_ps(R + 16*i); - // if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]); - // } - // for (int l1 = 0; l1 < k_step; ++l1) { - // auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v); - // auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)); - // for (int j = 0; j < q_step; ++j) { - // vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]); - // } - // } - // for (int j = 0; j < q_step; ++j) { - // auto R = qkv + (q_step*i1 + j)*stride_qkv; - // _mm512_storeu_ps(R + 16*i, vk[j]); - // } - //} - // - //for (int j = 0; j < q_step; ++j) { - // auto R = qkv + (q_step*i1 + j)*stride_qkv; - // //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); - // //auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16)); - // //auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2)); - // auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); - // auto smax = _mm512_reduce_max_ps(val); - // for (int i = 0; i < 8; ++i) vk[i] = _mm512_loadu_ps(R + 16*i); - // if (smax > M[j]) { - // if (M[j] > -INFINITY) { - // float m = expf(M[j] - smax); - // auto vm = _mm512_set1_ps(m); - // for (int i = 0; i < 8; ++i) { - // vk[i] = _mm512_mul_ps(vm, vk[i]); - // } - // S[j] *= m; - // } else { - // for (int i = 0; i < 8; ++i) vk[i] = _mm512_setzero_ps(); - // S[j] = 0; - // } - // M[j] = smax; - // } - // //auto vm = _mm512_set1_ps(M[j]); - // //val1 = v_expf(_mm512_sub_ps(val1, vm)); - // //val2 = v_expf(_mm512_sub_ps(val2, vm)); - // //S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2)); - // //_mm512_storeu_ps(cache + k_step*j, val1); - // //_mm512_storeu_ps(cache + k_step*j + 16, val2); - // val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j]))); - // S[j] += _mm512_reduce_add_ps(val); - // _mm512_storeu_ps(cache + k_step*j, val); - // for (int l1 = 0; l1 < k_step; ++l1) { - // if (cache[k_step*j + l1] < -20.0f) continue; - // auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v); - // auto vs = _mm512_set1_ps(cache[k_step*j + l1]); - // for (int i = 0; i < 8; ++i) { - // vk[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)), vk[i]); - // } - // } - // for (int i = 0; i < 8; ++i) _mm512_storeu_ps(R + 16*i, vk[i]); - //} } for (int j = 0; j < q_step; ++j) { - auto R = qkv + (q_step*i1 + j)*stride_qkv; GGML_ASSERT(S[j] > 0); - if (S[j] > 0) { - auto norm = _mm512_set1_ps(1/S[j]); - for (int i = 0; i < 8; ++i) { - auto r = _mm512_loadu_ps(R + 16*i); - _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(norm, r)); - } - } else { - std::memset(R, 0, 128*sizeof(float)); + auto R = qkv_cache + 128*j; + auto final_R = qkv + (q_step*i1 + j)*stride_qkv; + auto norm = _mm512_set1_ps(1/S[j]); + for (int i = 0; i < 8; ++i) { + auto r = _mm512_loadu_ps(R + 16*i); + _mm512_storeu_ps(final_R + 16*i, _mm512_mul_ps(norm, r)); } } }