From b5df88b12052f7c690f51e5e2978fa156790475e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 29 Aug 2024 17:52:56 +0300 Subject: [PATCH] Experimenting with flash attention on Zen4 This variant is better for long contexts but still not as good as no FA. --- ggml/src/ggml.c | 21 ++- ggml/src/iqk/iqk_mul_mat.cpp | 299 +++++++++++++++++++++++++++++++++-- ggml/src/iqk/iqk_mul_mat.h | 18 ++- 3 files changed, 325 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1b4656af..6fea3b22 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16217,6 +16217,25 @@ static void ggml_compute_forward_flash_attn_ext_f16( memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + if (nr%nth == 0 && max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 && + mask && mask->type == GGML_TYPE_F16) { + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % nth == ith) { + iqk_flash_helper_3(D, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3]), + (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), + (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), + (const void *)((const char *)mask->data), + scale, + (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2)*nb1)); // + iq1*ne1)*nb1)) + } + } + } + return; + } + const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -16296,7 +16315,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // } //} - iqk_flash_helper_2(D, nek1, nbk1, nbv1, + iqk_flash_helper_2(max_bias > 0, D, nek1, nbk1, nbv1, (const float *)((char *) q->data + iq1*nbq1 + iq2*nbq2 + iq3*nbq3), (const void *)((char *) k->data + ik2*nbk2 + ik3*nbk3), (const void *)((char *) v->data + iv2*nbv2 + iv3*nbv3), diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f675a987..2db82fa7 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6377,6 +6377,7 @@ void iqk_flash_helper(int nq, // number of elements in q } namespace { + template inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) { if (smax > M) { @@ -6412,6 +6413,91 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5 } } +template +inline void accumulate(int n, float scale, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) { + smax *= scale; + if (smax > M) { + if (M > -INFINITY) { + float ms = expf(M - smax); + auto vms = _mm512_set1_ps(ms); + for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); + S *= ms; + } else { + for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_setzero_ps(); + S = 0; + } + M = smax; + } + auto vs_all = v_expf(_mm512_fmsub_ps(_mm512_set1_ps(scale), _mm512_loadu_ps(saux), _mm512_set1_ps(M))); + _mm512_storeu_ps(saux, vs_all); + S += _mm512_reduce_add_ps(vs_all); + for (int j = 0; j < n; ++j) { + if (saux[j] < -18.f) continue; // ignore anything less than 1.5e-8 - it want't change the single precision result. 
+ auto vs = _mm512_set1_ps(saux[j]); + auto vr = (const ggml_half *)(v + stride_v*j); + for (int i = 0; i < nq/16; ++i) { + acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + } + } +} + +template +void flash_attn_noalibi_T(int nk, // number of rows in k + int stride_k, // distance between rows in k (in bytes) + int stride_v, + const float * q, // q vector + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // k matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float * qkv) { + GGML_ASSERT(mask); + constexpr int kNchunk = 16; + __m512 vq[nq/16]; + __m512 acc[nq/16]; + float saux[kNchunk]; + for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); + const ggml_half * mp = (const ggml_half *)mask; + ggml_half h_inf = GGML_FP32_TO_FP16(-INFINITY); + float M = -INFINITY; + float S = 0; + float smax = -INFINITY; + int last_ik = 0; + int ik = 0; + for (; ik < nk; ++ik) { + if (ik - last_ik == kNchunk) { + if (smax != -INFINITY) { + accumulate(kNchunk, scale, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + } + last_ik = ik; + smax = -INFINITY; + } + if (mp[ik] == h_inf) { + saux[ik - last_ik] = -INFINITY; + continue; + } + const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); + auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); + for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); + float s = _mm512_reduce_add_ps(sum); + saux[ik - last_ik] = s; + smax = std::max(smax, s); + } + int n_left = ik - last_ik; + if (n_left > 0 && smax != -INFINITY) { + for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY; + accumulate(n_left, scale, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + } + if (S > 0) { + auto norm = _mm512_set1_ps(1/S); + for (int i = 0; i < nq/16; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); + } else { + printf("Oops: S = %g. 
ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M); + GGML_ASSERT(false); + std::memset(qkv, 0, 128*sizeof(float)); + } +} + template void flash_attn_T(int nk, // number of rows in k int stride_k, // distance between rows in k (in bytes) @@ -6467,7 +6553,7 @@ void flash_attn_T(int nk, // number of rows in k smax = std::max(smax, s); } int n_left = ik - last_ik; - if (n_left > 0 & smax != -INFINITY) { + if (n_left > 0 && smax != -INFINITY) { for (int j = n_left; j < kNchunk; ++j) saux[j] = -INFINITY; accumulate(n_left, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); } @@ -6482,7 +6568,8 @@ void flash_attn_T(int nk, // number of rows in k } } -void iqk_flash_helper_2(int nq, // number of elements in q +void iqk_flash_helper_2(bool is_alibi, + int nq, // number of elements in q int nk, // number of rows in k int stride_k, // distance between rows in k (in bytes) int stride_v, @@ -6496,15 +6583,26 @@ void iqk_flash_helper_2(int nq, // number of elements in q float * qkv) { GGML_ASSERT(nq % 4 == 0); - switch (nq) { - case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; - default: break; - //default: GGML_ABORT("unhandled head size -> fatal error"); + if (is_alibi) { + switch (nq) { + case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + default: break; + } + } else { + switch (nq) { + case 64: flash_attn_noalibi_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + case 80: flash_attn_noalibi_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + case 96: flash_attn_noalibi_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + case 112: flash_attn_noalibi_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + case 128: flash_attn_noalibi_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + case 256: flash_attn_noalibi_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, qkv); return; + default: break; + } } if (mask) { @@ -6658,6 +6756,185 @@ bool iqk_soft_max_noalibi(int nc, int ir0, int ir1, int ne00, int ne01, return true; } +void iqk_flash_helper_3(int ne00, + int nq1, // number of elements in q + int nk1, // number of rows in k + int stride_q, + int stride_k, // distance between rows in k (in bytes) + int stride_v, // distance between rows in v (in bytes) + int stride_m, // distance between rows in mask (in bytes) + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q vector + const void * k, // k matrix. 
Assumed to be fp16, nq x nk elements + const void * v, + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float * qkv) { + stride_q /= sizeof(float); + // The following works + //for (int iq1 = 0; iq1 < nq1; ++iq1) { + // iqk_flash_helper_2(false, + // ne00, + // nk1, + // stride_k, + // stride_v, + // q, + // k, + // v, + // (const void *)((const char *)mask + iq1*stride_m), + // scale, + // 1.0f, + // nullptr, + // qkv); + // q += stride_q; + // qkv += stride_qkv; + //} + float cache[256]; + float S[16], M[16]; + __m512 vk[8]; + for (int i1 = 0; i1 < nq1/16; ++i1) { + for (int j = 0; j < 16; ++j) { + auto R = qkv + (16*i1 + j)*stride_qkv; + std::memset(R, 0, 128*sizeof(float)); + S[j] = 0; M[j] = -INFINITY; + } + for (int k1 = 0; k1 < nk1/16; ++k1) { + for (int l1 = 0; l1 < 16; ++l1) { + auto kr = (const ggml_half *)((const char *)k + (16*k1 + l1)*stride_k); + for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)); + for (int m1 = 0; m1 < 16; ++m1) { + // q index is 16*i1 + m1 + // k index is 16*k1 + l1 + const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(16*i1 + m1)) + 16*k1; + if (GGML_FP16_TO_FP32(mp[l1]) == -INFINITY) { + cache[16*m1 + l1] = -INFINITY; + continue; + } + auto qr = q + (16*i1 + m1)*stride_q; + auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr)); + for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum); + cache[16*m1 + l1] = scale*_mm512_reduce_add_ps(vsum); + } + } + for (int j = 0; j < 16; ++j) { + auto R = qkv + (16*i1 + j)*stride_qkv; + auto val = _mm512_loadu_ps(cache + 16*j); + auto smax = _mm512_reduce_max_ps(val); + for (int i = 0; i < 8; ++i) vk[i] = _mm512_loadu_ps(R + 16*i); + if (smax > M[j]) { + if (M[j] > -INFINITY) { + float m = expf(M[j] - smax); + auto vm = _mm512_set1_ps(m); + for (int i = 0; i < 8; ++i) { + vk[i] = _mm512_mul_ps(vm, vk[i]); + //auto r = _mm512_loadu_ps(R + 16*i); + //_mm512_storeu_ps(R + 16*i, _mm512_mul_ps(vm, r)); + } + S[j] *= m; + } else { + for (int i = 0; i < 8; ++i) vk[i] = _mm512_setzero_ps(); + //std::memset(R, 0, 128*sizeof(float)); + S[j] = 0; + } + M[j] = smax; + } + val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j]))); + S[j] += _mm512_reduce_add_ps(val); + _mm512_storeu_ps(cache + 16*j, val); + for (int l1 = 0; l1 < 16; ++l1) { + if (cache[16*j + l1] < -20.0f) continue; + auto vr = (const ggml_half *)((const char *)v + (16*k1 + l1)*stride_v); + auto vs = _mm512_set1_ps(cache[16*j + l1]); + for (int i = 0; i < 8; ++i) { + vk[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i)), vk[i]); + } + } + for (int i = 0; i < 8; ++i) _mm512_storeu_ps(R + 16*i, vk[i]); + } + } + for (int j = 0; j < 16; ++j) { + auto R = qkv + (16*i1 + j)*stride_qkv; + GGML_ASSERT(S[j] > 0); + if (S[j] > 0) { + auto norm = _mm512_set1_ps(1/S[j]); + for (int i = 0; i < 8; ++i) { + auto r = _mm512_loadu_ps(R + 16*i); + _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(norm, r)); + } + } else { + std::memset(R, 0, 128*sizeof(float)); + } + } + } + return; + if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1); + //GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0); + auto vinf = _mm512_set1_ps(-INFINITY); + for (int i1 = 0; i1 < nq1/16; ++i1) { + //int iq1 = 16*i1; + for (int j1 = 0; j1 < 16; ++j1) { + S[j1] = 0; M[j1] = -INFINITY; + std::memset(qkv + j1*stride_v, 0, ne00*sizeof(float)); + } + for (int ik = 0; ik < nk1; ik += 16) { + 
///////////////////////////////////////////////////////////////////////////////// + const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); + DataInfo info{cache, (const char *)q, 16*sizeof(float), size_t(stride_q)*sizeof(float), 0, 0, nullptr, 0}; + mul_mat_fX_fY_T<4, ggml_half, float>(ne00, (const void *)kr, stride_k, info, 16); + ///////////////////////////////////////////////////////////////////////////////// + float * R = qkv; + for (int j1 = 0; j1 < 16; ++j1) { + int iq1 = 16*i1 + j1; + float * C = cache + 16*j1; + auto qk = _mm512_loadu_ps(C); + const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*iq1); + auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i*)mp), _mm256_setzero_si256()); + qk = _mm512_mask_blend_ps(m16, vinf, qk); + float smax = _mm512_reduce_max_ps(qk); + if (smax > M[j1]) { + if (M[j1] > -INFINITY) { + float m = expf(M[j1] - smax); + auto ms = _mm512_set1_ps(m); + for (int i = 0; i < ne00/16; ++i) _mm512_storeu_ps(R + 16*i, _mm512_mul_ps(ms, _mm512_loadu_ps(R + 16*i))); + S[j1] *= m; + } else { + std::memset(R, 0, ne00*sizeof(float)); + S[j1] = 0; + } + M[j1] = smax; + } + auto vs = v_expf(_mm512_sub_ps(qk, _mm512_set1_ps(M[j1]))); + S[j1] += _mm512_reduce_add_ps(vs); + _mm512_storeu_ps(C, vs); + for (int jk = 0; jk < 16; ++jk) { + vs = _mm512_set1_ps(C[jk]); + const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(ik + jk)); + for (int i = 0; i < ne00/16; ++i) { + auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); + auto r = _mm512_loadu_ps(qkv + 16*i); + _mm512_storeu_ps(qkv + 16*i, _mm512_fmadd_ps(vs, v, r)); + } + } + R += stride_qkv; + } + } + for (int j1 = 0; j1 < 16; ++j1) { + if (S[j1] > 0) { + //GGML_ASSERT(S[j1] > 0); + auto norm = _mm512_set1_ps(1/S[j1]); + for (int i = 0; i < ne00/16; ++i) { + auto r = _mm512_loadu_ps(qkv + 16*i); + _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r)); + } + } + qkv += stride_qkv; + } + q += 16*stride_q; + } + +} + + #else // IQK_IMPLEMENT bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index e096043a..c429d3d3 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -44,7 +44,8 @@ void iqk_flash_helper(int nq, // number of elements in q float slope, float * qk); // softmax(k*q) - k elements -void iqk_flash_helper_2(int nq, // number of elements in q +void iqk_flash_helper_2(bool is_alibi, + int nq, // number of elements in q int nk, // number of rows in k int stride_k, // distance between rows in k (in bytes) int stride_v, // distance between rows in k (in bytes) @@ -57,6 +58,21 @@ void iqk_flash_helper_2(int nq, // number of elements in q float * qk, float * qkv); // softmax(k*q) - k elements +void iqk_flash_helper_3(int ne00, + int nq, // number of elements in q + int nk, // number of rows in k + int stride_q, + int stride_k, // distance between rows in k (in bytes) + int stride_v, // distance between rows in v (in bytes) + int stride_m, // distance between rows in mask (in bytes) + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q vector + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float * qkv); // v*softmax(k*q) + #ifdef __cplusplus } #endif
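
Note (not part of the patch, added for review): the AVX-512 accumulate() helpers and flash_attn_noalibi_T above implement the standard streaming ("online") softmax recurrence: a running maximum M and running sum S are carried per query row, the existing accumulator is rescaled by exp(M_old - M_new) whenever a larger score shows up, and the division by S happens once at the very end. A minimal scalar sketch of the same update, assuming fp32 q/k/v for simplicity (the kernels use fp16 k/v) and a mask whose entries are either 0 or -INFINITY; the name flash_attn_row_ref is illustrative and not part of this change:

    #include <cmath>
    #include <cstring>

    // One query row: out = softmax(scale * q.K^T) * V, with masked keys skipped.
    // Each k/v row is visited exactly once thanks to the running (M, S) pair.
    static void flash_attn_row_ref(int D, int nk, float scale,
                                   const float * q,     // D elements
                                   const float * k,     // nk x D
                                   const float * v,     // nk x D
                                   const float * mask,  // nk elements, 0 or -INFINITY
                                   float * out) {       // D elements
        float M = -INFINITY, S = 0.0f;
        std::memset(out, 0, D*sizeof(float));
        for (int j = 0; j < nk; ++j) {
            if (mask[j] == -INFINITY) continue;          // masked key contributes nothing
            float s = 0;
            for (int i = 0; i < D; ++i) s += q[i]*k[j*D + i];
            s *= scale;
            if (s > M) {                                 // new maximum: rescale what we have so far
                float ms = M > -INFINITY ? expf(M - s) : 0.0f;
                for (int i = 0; i < D; ++i) out[i] *= ms;
                S *= ms;
                M = s;
            }
            float p = expf(s - M);                       // softmax weight relative to the running max
            S += p;
            for (int i = 0; i < D; ++i) out[i] += p*v[j*D + i];
        }
        if (S > 0) for (int i = 0; i < D; ++i) out[i] /= S;  // final softmax normalization
    }

The vectorized code does the same thing in chunks of 16 keys: it first computes the 16 scores and their maximum, performs at most one rescale per chunk, and only then accumulates the 16 weighted v rows.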
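
The new iqk_flash_helper_3 entry point goes one step further and processes a whole (iq2, iq3) plane at once, tiling both queries and keys in blocks of 16: each fp16 K row is converted to fp32 once and reused for 16 query rows, while per-query running maxima M[j] and sums S[j] are carried across key tiles. A scalar sketch of that loop structure, under the same simplifying assumptions as above (fp32 inputs, mask entries 0 or -INFINITY, nq1 and nk1 multiples of 16 as the current code requires; flash_attn_tiled_ref is an illustrative name):

    #include <algorithm>
    #include <cmath>
    #include <cstring>

    static void flash_attn_tiled_ref(int D, int nq1, int nk1, float scale,
                                     const float * q,     // nq1 x D
                                     const float * k,     // nk1 x D
                                     const float * v,     // nk1 x D
                                     const float * mask,  // nq1 x nk1, 0 or -INFINITY
                                     float * qkv) {       // nq1 x D
        float S[16], M[16], cache[16*16];
        for (int i1 = 0; i1 < nq1/16; ++i1) {              // block of 16 query rows
            for (int j = 0; j < 16; ++j) { S[j] = 0; M[j] = -INFINITY; }
            std::memset(qkv + 16*i1*D, 0, 16*D*sizeof(float));
            for (int k1 = 0; k1 < nk1/16; ++k1) {          // block of 16 key rows
                // Score tile: cache[16*j + l] = scale * (q row 16*i1+j) . (k row 16*k1+l), or -INF if masked.
                for (int l = 0; l < 16; ++l) {
                    const float * kr = k + (16*k1 + l)*D;  // loaded once, reused for 16 queries
                    for (int j = 0; j < 16; ++j) {
                        if (mask[(16*i1 + j)*nk1 + 16*k1 + l] == -INFINITY) { cache[16*j + l] = -INFINITY; continue; }
                        const float * qr = q + (16*i1 + j)*D;
                        float s = 0;
                        for (int i = 0; i < D; ++i) s += qr[i]*kr[i];
                        cache[16*j + l] = scale*s;
                    }
                }
                // Online-softmax update per query row, 16 keys at a time.
                for (int j = 0; j < 16; ++j) {
                    float * R = qkv + (16*i1 + j)*D;
                    float smax = -INFINITY;
                    for (int l = 0; l < 16; ++l) smax = std::max(smax, cache[16*j + l]);
                    if (smax == -INFINITY) continue;        // whole key tile masked for this query
                    if (smax > M[j]) {                      // rescale the running accumulator
                        float ms = M[j] > -INFINITY ? expf(M[j] - smax) : 0.0f;
                        for (int i = 0; i < D; ++i) R[i] *= ms;
                        S[j] *= ms;
                        M[j] = smax;
                    }
                    for (int l = 0; l < 16; ++l) {
                        if (cache[16*j + l] == -INFINITY) continue;
                        float p = expf(cache[16*j + l] - M[j]);
                        S[j] += p;
                        const float * vr = v + (16*k1 + l)*D;
                        for (int i = 0; i < D; ++i) R[i] += p*vr[i];
                    }
                }
            }
            for (int j = 0; j < 16; ++j) {                  // final 1/S normalization
                float * R = qkv + (16*i1 + j)*D;
                if (S[j] > 0) for (int i = 0; i < D; ++i) R[i] /= S[j];
            }
        }
    }

Compared with flash_attn_noalibi_T, this layout converts each fp16 K row once per block of 16 query rows instead of once per query row, at the cost of a 16x16 score cache and per-row (S, M) state.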