From e4959f9e46b85aa63ec3e5efd8720fc3ce906395 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 30 Aug 2024 08:42:34 +0300 Subject: [PATCH] Experimenting with flash attention on Zen4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This version outperforms no-FA up to 16k tokens, but still becomes slower at 32k. Here are the t/s for LLaMA-3.1-8B on a Ryzen-7950X | test | t/s no FA | Georgi FA | This commit FA | | --------: | ---------------: | -------------: | --------------: | | pp256 | 193.46 ± 2.40 | 193.19 ± 5.07 | 197.73 ± 0.72 | | pp512 | 192.23 ± 1.83 | 188.14 ± 0.63 | 194.38 ± 0.69 | | pp1024 | 189.06 ± 0.72 | 170.81 ± 4.82 | 191.12 ± 1.47 | | pp2048 | 181.92 ± 1.21 | 140.36 ± 1.77 | 184.57 ± 1.20 | | pp4096 | 165.10 ± 0.95 | 117.50 ± 0.35 | 168.79 ± 0.50 | | pp8192 | 137.48 ± 0.75 | 68.54 ± 1.00 | 148.21 ± 0.64 | | pp16384 | 100.35 ± 0.93 | | 105.14 ± 0.00 | | pp32768 | 64.44 | | 57.36 | Didn't have the patience to run Georgi's FA at 16k tokens. No error estimate on the 32k result as I only ran 1 sample. 
--- ggml/src/iqk/iqk_mul_mat.cpp | 111 ++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f966fb04..ad563f36 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6772,7 +6772,6 @@ void iqk_flash_helper_3(int ne00, float * qkv) { constexpr int q_step = 8; constexpr int k_step = 32; //16; - //if (nq1%q_step != 0 || nk1%k_step != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1); if (nq1%q_step != 0 || nk1%k_step != 0) { for (int iq1 = 0; iq1 < nq1; ++iq1) { iqk_flash_helper_2(false, ne00, nk1, stride_k, stride_v, @@ -6784,49 +6783,19 @@ void iqk_flash_helper_3(int ne00, return; } stride_q /= sizeof(float); - // The following works - //for (int iq1 = 0; iq1 < nq1; ++iq1) { - // iqk_flash_helper_2(false, - // ne00, - // nk1, - // stride_k, - // stride_v, - // q, - // k, - // v, - // (const void *)((const char *)mask + iq1*stride_m), - // scale, - // 1.0f, - // nullptr, - // qkv); - // q += stride_q; - // qkv += stride_qkv; - //} const ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY); float cache[q_step*k_step]; float S[q_step], M[q_step]; - __m512 vk[8]; + __m512 vk[16]; __m512 vms[q_step]; + __m512 vals[k_step/16]; bool need_scaling[q_step]; auto vscale = _mm512_set1_ps(scale); for (int i1 = 0; i1 < nq1/q_step; ++i1) { for (int j = 0; j < q_step; ++j) { - //auto R = qkv + (q_step*i1 + j)*stride_qkv; - //std::memset(R, 0, 128*sizeof(float)); S[j] = 0; M[j] = -INFINITY; } for (int k1 = 0; k1 < nk1/k_step; ++k1) { - /////////////////////////////////////////////////////////////////////////////////// - //const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*16*k1); - //DataInfo info{cache, (const char *)(q + 16*i1*stride_q), 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0}; - //mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16); - 
/////////////////////////////////////////////////////////////////////////////////// - //for (int l1 = 0; l1 < 16; ++l1) { - // for (int m1 = 0; m1 < 16; ++m1) { - // const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(16*i1 + m1)) + 16*k1; - // cache[16*m1 + l1] = GGML_FP16_TO_FP32(mp[l1]) == -INFINITY ? -INFINITY : scale*cache[16*m1 + l1]; - // } - //} for (int l1 = 0; l1 < k_step; ++l1) { auto kr = (const ggml_half *)((const char *)k + (k_step*k1 + l1)*stride_k); for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)); @@ -6844,14 +6813,16 @@ void iqk_flash_helper_3(int ne00, cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum); } } - // This variant is much slower than the one below for (int j = 0; j < q_step; ++j) { auto R = qkv + (q_step*i1 + j)*stride_qkv; - auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); - auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16)); - auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2)); - //auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); - //auto smax = _mm512_reduce_max_ps(val); + for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l)); + auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1])); + //auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3]))); + //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); + //auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16)); + //auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2)); + ////auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j)); + ////auto smax = _mm512_reduce_max_ps(val); need_scaling[j] = false; if (smax > M[j]) { if (M[j] > -INFINITY) { @@ -6865,33 +6836,67 @@ void iqk_flash_helper_3(int ne00, } M[j] = smax; } - val1 = v_expf(_mm512_sub_ps(val1, 
_mm512_set1_ps(M[j]))); - val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j]))); - S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2)); - _mm512_storeu_ps(cache + k_step*j, val1); - _mm512_storeu_ps(cache + k_step*j + 16, val2); - //val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j]))); - //S[j] += _mm512_reduce_add_ps(val); - //_mm512_storeu_ps(cache + k_step*j, val); + auto vm = _mm512_set1_ps(M[j]); + for (int l = 0; l < k_step/16; ++l) { + vals[l] = v_expf(_mm512_sub_ps(vals[l], vm)); + _mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]); + } + S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1])); + //S[j] += _mm512_reduce_add_ps(_mm512_add_ps(_mm512_add_ps(vals[0], vals[1]), _mm512_add_ps(vals[2], vals[3]))); + //val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j]))); + //val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j]))); + //S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2)); + //_mm512_storeu_ps(cache + k_step*j, val1); + //_mm512_storeu_ps(cache + k_step*j + 16, val2); + ////val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j]))); + ////S[j] += _mm512_reduce_add_ps(val); + ////_mm512_storeu_ps(cache + k_step*j, val); } - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < 8; i += 2) { for (int j = 0; j < q_step; ++j) { auto R = qkv + (q_step*i1 + j)*stride_qkv; - vk[j] = _mm512_loadu_ps(R + 16*i); - if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]); + vk[2*j+0] = _mm512_loadu_ps(R + 16*i); + vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16); + if (need_scaling[j]) { + vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]); + vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]); + } } for (int l1 = 0; l1 < k_step; ++l1) { auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v); - auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)); + auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0)); + auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1)); for (int j = 0; j < q_step; ++j) { - 
vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]); + auto vs = _mm512_set1_ps(cache[k_step*j + l1]); + vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]); + vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]); } } for (int j = 0; j < q_step; ++j) { auto R = qkv + (q_step*i1 + j)*stride_qkv; - _mm512_storeu_ps(R + 16*i, vk[j]); + _mm512_storeu_ps(R + 16*i, vk[2*j+0]); + _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]); } } + //for (int i = 0; i < 8; ++i) { + // for (int j = 0; j < q_step; ++j) { + // auto R = qkv + (q_step*i1 + j)*stride_qkv; + // vk[j] = _mm512_loadu_ps(R + 16*i); + // if (need_scaling[j]) vk[j] = _mm512_mul_ps(vk[j], vms[j]); + // } + // for (int l1 = 0; l1 < k_step; ++l1) { + // auto vr = (const ggml_half *)((const char *)v + (k_step*k1 + l1)*stride_v); + // auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)); + // for (int j = 0; j < q_step; ++j) { + // vk[j] = _mm512_fmadd_ps(v, _mm512_set1_ps(cache[k_step*j + l1]), vk[j]); + // } + // } + // for (int j = 0; j < q_step; ++j) { + // auto R = qkv + (q_step*i1 + j)*stride_qkv; + // _mm512_storeu_ps(R + 16*i, vk[j]); + // } + //} + // //for (int j = 0; j < q_step; ++j) { // auto R = qkv + (q_step*i1 + j)*stride_qkv; // //auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));