From a02e78d5c155eaacfe9237d97e9b57a2c9baccfc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 29 Aug 2024 10:06:29 +0300 Subject: [PATCH] WIP --- ggml/src/iqk/iqk_mul_mat.cpp | 48 +++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f7abd243..f675a987 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6391,17 +6391,24 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5 } M = smax; } - //auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M))); - //_mm256_storeu_ps(saux, vs_all); auto vs_all = v_expf(_mm512_sub_ps(_mm512_loadu_ps(saux), _mm512_set1_ps(M))); _mm512_storeu_ps(saux, vs_all); + S += _mm512_reduce_add_ps(vs_all); + //for (int j = 0; j < n; ++j) { + // if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it want't change the single precision result. + // auto vs = _mm512_set1_ps(saux[j]); + // auto vr = (const ggml_half *)(v + stride_v*j); + // for (int i = 0; i < nq/16; ++i) { + // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + // } + //} + __m512 vec_v[nq/16]; for (int j = 0; j < n; ++j) { - S += saux[j]; - auto vs = _mm512_set1_ps(saux[j]); + if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it want't change the single precision result. auto vr = (const ggml_half *)(v + stride_v*j); - for (int i = 0; i < nq/16; ++i) { - acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - } + for (int i = 0; i < nq/16; ++i) vec_v[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); + auto vs = _mm512_set1_ps(saux[j]); + for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_fmadd_ps(vs, vec_v[i], acc[i]); } } @@ -6420,8 +6427,14 @@ void flash_attn_T(int nk, // number of rows in k __m512 vq[nq/16]; __m512 acc[nq/16]; float saux[kNchunk]; + //float mv32[kNchunk]; for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr; + //if (!mp) std::memset(mv32, 0, kNchunk*sizeof(float)); + //else { + // auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp))); + // _mm512_storeu_ps(mv32, vmask); + //} float M = -INFINITY; float S = 0; float smax = -INFINITY; @@ -6429,12 +6442,23 @@ void flash_attn_T(int nk, // number of rows in k int ik = 0; for (; ik < nk; ++ik) { if (ik - last_ik == kNchunk) { - accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + //accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + if (smax != -INFINITY) { + accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + } last_ik = ik; smax = -INFINITY; + //if (ik + kNchunk <= nk) { + // auto vmask = _mm512_mul_ps(_mm512_set1_ps(slope), _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp))); + // _mm512_storeu_ps(mv32, vmask); + //} } + //const float mv = mv32[ik - last_ik]; const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; - if (mv == -INFINITY) break; + if (mv == -INFINITY) { + saux[ik - last_ik] = -INFINITY; + continue; + } const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); @@ -6442,8 +6466,10 @@ void flash_attn_T(int nk, // number of rows in k saux[ik - last_ik] = s; smax = std::max(smax, s); } - if (ik > last_ik) { - accumulate(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + int n_left = ik - last_ik; + if (n_left > 0 & smax != -INFINITY) { + for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY; + accumulate(n_left, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); } if (S > 0) { auto norm = _mm512_set1_ps(1/S);