From cfee1b68ec6ba4eeeffb4f41a194011664bb09db Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 24 Aug 2024 14:31:13 +0300 Subject: [PATCH] WIP: plugging into ggml_compute_forward_flash_attn_ext_f16 OK, if we take into account that the mask is diagonal and skip further computations once we encounter -INFINITY, we can speed it up and make it on par with no-FA. Better than nothing, but still no luck. --- ggml/src/iqk/iqk_mul_mat.cpp | 121 ++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 44 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 78f1e97e..22deb56a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6394,63 +6394,96 @@ void iqk_flash_helper_2(int nq, // number of elements in q float M = -INFINITY; float S = 0; int ik = 0; - if (nk >= 8) { - float s8[8]; - for (int ik8 = 0; ik8 < nk/8; ++ik8) { - float smax = -INFINITY; - for (int j = 0; j < 8; ++j) { - const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j)); - auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); - for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); - float s = scale * _mm512_reduce_add_ps(sum); - if (mp) s += slope*GGML_FP16_TO_FP32(mp[8*ik8+j]); - s8[j] = s; - smax = std::max(smax, s); - } - if (smax > M) { - float ms = expf(M - smax); - auto scale = _mm512_set1_ps(ms); - for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale); - S *= ms; - M = smax; - } - _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M)))); - for (int j = 0; j < 8; ++j) { - const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j)); - auto vs = _mm512_set1_ps(s8[j]); - S += s8[j]; - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - } - } - } - ik = 8*(nk/8); - } + //if (nk >= 8) { + // float s8[8]; + // for (int ik8 = 0; ik8 < nk/8; ++ik8) { + // float smax = -INFINITY; + // for (int j = 0; j < 8; ++j) { + // const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f; + // if (mv == -INFINITY) { + // s8[j] = -INFINITY; continue; + // } + // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j)); + // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); + // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); + // float s = scale * _mm512_reduce_add_ps(sum) + mv; + // s8[j] = s; + // smax = std::max(smax, s); + // } + // if (smax > M) { + // if (M > -INFINITY) { + // float ms = expf(M - smax); + // auto scale = _mm512_set1_ps(ms); + // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale); + // S *= ms; + // } else { + // for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); + // } + // M = smax; + // } + // if (smax == -INFINITY) break; + // _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M)))); + // for (int j = 0; j < 8; ++j) { + // if (s8[j] <= 0.0f) continue; + // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j)); + // auto vs = _mm512_set1_ps(s8[j]); + // S += s8[j]; + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + // } + // } + // } + // ik = 8*(nk/8); + //} + //int last_i = -1; for (; ik < nk; ++ik) { const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; - if (mv == -INFINITY) continue; + if (mv == -INFINITY) break; //continue; + //last_i = ik; const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik); + //auto sum1 = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+0))); + //auto sum2 = _mm512_mul_ps(vq[1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+1))); + //for (int i = 2; i < 8; i += 2) { + // sum1 = _mm512_fmadd_ps(vq[i+0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0)), sum1); + // sum2 = _mm512_fmadd_ps(vq[i+1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1)), sum2); + //} + //float s = scale * _mm512_reduce_add_ps(_mm512_add_ps(sum1, sum2)) + mv; auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); float s = scale * _mm512_reduce_add_ps(sum) + mv; if (s > M) { - float ms = expf(M - s); - auto vms = _mm512_set1_ps(ms); - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); + if (M > -INFINITY) { + //if (M - s > -20.f) { + float ms = expf(M - s); + auto vms = _mm512_set1_ps(ms); + for (int i = 0; i < 8; ++i) { + acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); + } + S = ms*S + 1; + } else { + for (int i = 0; i < 8; ++i) { + acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); + } + S = 1; } - S = ms*S + 1; M = s; } else { - float vs = expf(s - M); - auto vvs = _mm512_set1_ps(vs); - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - } - S += vs; + //if (s - M > -20.f) { + float vs = expf(s - M); + auto vvs = _mm512_set1_ps(vs); + for (int i = 0; i < 8; ++i) { + acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + } + S += vs; + //} } } + //if (last_i < 0) { + // std::memset(qkv, 0, 128*sizeof(float)); + // return; + //} + //printf("%s: nk = %d, last_i = %d\n", __func__, nk, last_i); auto norm = _mm512_set1_ps(1/S); for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); return;