diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 19ca8684..165fae90 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6103,7 +6103,7 @@ inline __m256 v_expf(__m256 x) { } #endif -inline float prepare_softmax(int nc, const float * sp, float scale, float slope, const char * mp, bool use_fp16, __m512 * values) { +inline std::pair<float, int> prepare_softmax(int nc, const float * sp, float scale, float slope, const char * mp, bool use_fp16, __m512 * values) { __m512 vscale = _mm512_set1_ps(scale); __m512 vmax = _mm512_set1_ps(-INFINITY); if (mp) { @@ -6111,6 +6111,9 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope, if (use_fp16) { const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp; for (int i = 0; i < nc/16; ++i) { + if (GGML_FP16_TO_FP32(mp_f16[16*i]) == -INFINITY) { + nc = 16*i; break; + } const __m512 m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + i)); const __m512 x = _mm512_loadu_ps(sp + 16*i); const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x)); @@ -6120,6 +6123,9 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope, } else { const float * mp_f32 = (const float *)mp; for (int i = 0; i < nc/16; ++i) { + if (mp_f32[16*i] == -INFINITY) { + nc = 16*i; break; + } const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i); const __m512 x = _mm512_loadu_ps(sp + 16*i); const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x)); @@ -6135,11 +6141,11 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope, values[i] = y; } } - return _mm512_reduce_max_ps(vmax); + return std::make_pair(_mm512_reduce_max_ps(vmax), nc); } -inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) { +inline std::pair<float, int> prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) { __m512 vscale = _mm512_set1_ps(scale); __m512 vmax = 
_mm512_set1_ps(-INFINITY); float scalar_max = -INFINITY; @@ -6147,27 +6153,26 @@ inline float prepare_softmax(int nc, float * sp, float scale, float slope, const __m512 vslope = _mm512_set1_ps(slope); if (use_fp16) { const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp; - __m512 vmax1 = _mm512_set1_ps(-INFINITY); - for (int i = 0; i < nc/32; ++i) { - const __m512 m1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+0)); - const __m512 m2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+1)); - const __m512 x1 = _mm512_loadu_ps(sp + 32*i + 0); - const __m512 x2 = _mm512_loadu_ps(sp + 32*i + 16); - const __m512 y1 = _mm512_fmadd_ps(vslope, m1, _mm512_mul_ps(vscale, x1)); - const __m512 y2 = _mm512_fmadd_ps(vslope, m2, _mm512_mul_ps(vscale, x2)); - vmax = _mm512_max_ps(vmax , y1); - vmax1 = _mm512_max_ps(vmax1, y2); - _mm512_storeu_ps(sp + 32*i + 0, y1); - _mm512_storeu_ps(sp + 32*i + 16, y2); + for (int i = 0; i < nc/16; ++i) { + if (GGML_FP16_TO_FP32(mp_f16[16*i]) == -INFINITY) { + nc = 16*i; break; + } + const __m512 m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16+i)); + const __m512 x = _mm512_loadu_ps(sp + 16*i); + const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x)); + vmax = _mm512_max_ps(vmax , y); + _mm512_storeu_ps(sp + 16*i, y); } - vmax = _mm512_max_ps(vmax, vmax1); - for (int i = 32*(nc/32); i < nc; ++i) { + for (int i = 16*(nc/16); i < nc; ++i) { sp[i] = scale*sp[i] + slope*GGML_FP16_TO_FP32(mp_f16[i]); scalar_max = std::max(scalar_max, sp[i]); } } else { const float * mp_f32 = (const float *)mp; for (int i = 0; i < nc/16; ++i) { + if (mp_f32[16*i] == -INFINITY) { + nc = 16*i; break; + } const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i); const __m512 x = _mm512_loadu_ps(sp + 16*i); const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x)); @@ -6192,7 +6197,7 @@ inline float prepare_softmax(int nc, float * sp, float scale, float slope, const } } float vector_max = 
_mm512_reduce_max_ps(vmax); - return std::max(scalar_max, vector_max); + return std::make_pair(std::max(scalar_max, vector_max), nc); } inline float do_soft_max(int nc, float * sp, float max) { @@ -6226,26 +6231,26 @@ inline void do_scale(int nc, const float * sp, float * dp, float scale) { } void softmax_extended(int nc, float * sp, float * dp, float scale, float slope, const char * mp, bool mask_is_fp16) { - auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16); - auto sum = do_soft_max(nc, sp, max); - do_scale(nc, sp, dp, 1.f/sum); + auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16); + auto sum = do_soft_max(ncc, sp, max); + do_scale(ncc, sp, dp, 1.f/sum); } -float softmax_extended_dont_scale(int nc, float * sp, float scale, float slope, const char * mp, bool mask_is_fp16) { +std::pair<float, int> softmax_extended_dont_scale(int nc, float * sp, float scale, float slope, const char * mp, bool mask_is_fp16) { if (nc/16 <= 16 && 16*(nc/16) == nc) { __m512 v[16]; - auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16, v); + auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16, v); auto vmax = _mm512_set1_ps(-max); auto vsum = _mm512_setzero_ps(); - for (int i = 0; i < nc/16; ++i) { + for (int i = 0; i < ncc/16; ++i) { v[i] = v_expf(_mm512_add_ps(v[i], vmax)); vsum = _mm512_add_ps(vsum, v[i]); _mm512_storeu_ps(sp + 16*i, v[i]); } - return _mm512_reduce_add_ps(vsum); + return std::make_pair(_mm512_reduce_add_ps(vsum), ncc); } - auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16); - return do_soft_max(nc, sp, max); + auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16); + return std::make_pair(do_soft_max(ncc, sp, max), ncc); } } @@ -6372,15 +6377,16 @@ void iqk_flash_helper(int nq, // number of elements in q } namespace { +template <int nq> inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) { if (smax > M) { if (M > 
-INFINITY) { float ms = expf(M - smax); auto vms = _mm512_set1_ps(ms); - for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); + for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); S *= ms; } else { - for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); + for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_setzero_ps(); S = 0; } M = smax; @@ -6393,11 +6399,61 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5 S += saux[j]; auto vs = _mm512_set1_ps(saux[j]); auto vr = (const ggml_half *)(v + stride_v*j); - for (int i = 0; i < 8; ++i) { + for (int i = 0; i < nq/16; ++i) { acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); } } } + +template <int nq> +void flash_attn_T(int nk, // number of rows in k + int stride_k, // distance between rows in k (in bytes) + int stride_v, + const float * q, // q vector + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nk elements + float scale, + float slope, + float * qkv) { + constexpr int kNchunk = 16; + __m512 vq[nq/16]; + __m512 acc[nq/16]; + float saux[kNchunk]; + for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); + const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr; + float M = -INFINITY; + float S = 0; + float smax = -INFINITY; + int last_ik = 0; + int ik = 0; + for (; ik < nk; ++ik) { + if (ik - last_ik == kNchunk) { + accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + last_ik = ik; + smax = -INFINITY; + } + const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; + if (mv == -INFINITY) break; + const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); + auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); + for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); + float s = scale * _mm512_reduce_add_ps(sum) + mv; + saux[ik - last_ik] = s; + smax = std::max(smax, s); + } + if (ik > last_ik) { + accumulate<nq>(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + } + if (S > 0) { + auto norm = _mm512_set1_ps(1/S); + for (int i = 0; i < nq/16; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); + } else { + printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M); + GGML_ASSERT(false); + std::memset(qkv, 0, 128*sizeof(float)); + } +} } void iqk_flash_helper_2(int nq, // number of elements in q @@ -6413,167 +6469,25 @@ void iqk_flash_helper_2(int nq, // number of elements in q float * qk, float * qkv) { GGML_ASSERT(nq % 4 == 0); - //GGML_ASSERT(nq / 16 <= 16); - if (nq == 128) { - constexpr int kNchunk = 16; - const ggml_half * mp = mask ? 
(const ggml_half *)mask : nullptr; - __m512 vq[8]; - __m512 acc[8]; // = {}; - for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); - float M = -INFINITY; - float S = 0; - float saux[kNchunk]; - float smax = -INFINITY; - int last_ik = 0; - int ik = 0; - for (; ik < nk; ++ik) { - if (ik - last_ik == kNchunk) { - accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); - //if (smax > M) { - // if (M > -INFINITY) { - // float ms = expf(M - smax); - // auto vms = _mm512_set1_ps(ms); - // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); - // S *= ms; - // } - // M = smax; - //} - //auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M))); - //_mm256_storeu_ps(saux, vs_all); - //for (int j = 0; j < 8; ++j) { - // S += saux[j]; - // auto vs = _mm512_set1_ps(saux[j]); - // auto vr = (const ggml_half *)((const char *)v + stride_v*(last_ik + j)); - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - // } - //} - last_ik = ik; - smax = -INFINITY; + switch (nq) { + case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return; + default: break; + //default: GGML_ABORT("unhandled head size -> fatal error"); + } + + if (mask) { + const ggml_half * mp = (const ggml_half *)mask; + for (int i = 0; i < nk; ++i) { + if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) { + nk = i; break; } - const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; - if (mv == -INFINITY) break; //continue; - const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); - auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); - for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); - float s = scale * _mm512_reduce_add_ps(sum) + mv; - saux[ik - last_ik] = s; - smax = std::max(smax, s); } - if (ik > last_ik) { - accumulate(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); - //if (smax > M) { - // if (M > -INFINITY) { - // float ms = expf(M - smax); - // auto vms = _mm512_set1_ps(ms); - // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); - // S *= ms; - // } - // M = smax; - //} - //for (int j = last_ik; j < ik; ++j) { - // float s = expf(saux[j - last_ik] - M); - // S += s; - // auto vs = _mm512_set1_ps(s); - // auto vr = (const ggml_half *)((const char *)v + stride_v*j); - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - // } - //} - } - if (S > 0) { - auto norm = _mm512_set1_ps(1/S); - for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); - } else { - printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M); - GGML_ASSERT(false); - std::memset(qkv, 0, 128*sizeof(float)); - } - return; - //int ik = 0; - //if (false && nk >= 8) { - // bool finished = false; - // float s8[8]; - // for (int ik8 = 0; ik8 < nk/8; ++ik8) { - // float smax = -INFINITY; - // int j = 0; - // for (; j < 8; ++j) { - // const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f; - // if (mv == -INFINITY) break; - // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j)); - // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); - // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); - // float s = scale * _mm512_reduce_add_ps(sum) + mv; - // s8[j] = s; - // smax = std::max(smax, s); - // } - // if (smax == -INFINITY) { finished = true; break; } - // if (smax > M) { - // if (M > -INFINITY) { - // float ms = expf(M - smax); - // auto scale = _mm512_set1_ps(ms); - // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale); - // S *= ms; - // } else { - // for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); - // S = 0; - // } - // M = smax; - // } - // _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M)))); - // for (int l = 0; l < j; ++l) { - // if (s8[l] <= 0.0f) continue; - // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+l)); - // auto vs = _mm512_set1_ps(s8[l]); - // S += s8[l]; - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - // } - // } - // } - // ik = finished ? nk : 8*(nk/8); - //} - //int last_ik = 0; - //for (; ik < nk; ++ik) { - // const float mv = mp ? 
slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; - // if (mv == -INFINITY) break; //continue; - // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); - // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik); - // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); - // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); - // float s = scale * _mm512_reduce_add_ps(sum) + mv; - // if (s > M) { - // if (M > -INFINITY) { - // float ms = expf(M - s); - // auto vms = _mm512_set1_ps(ms); - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); - // } - // S = ms*S + 1; - // } else { - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); - // } - // S = 1; - // } - // M = s; - // } else { - // float vs = expf(s - M); - // auto vvs = _mm512_set1_ps(vs); - // for (int i = 0; i < 8; ++i) { - // acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - // } - // S += vs; - // } - //} - //if (S > 0) { - // auto norm = _mm512_set1_ps(1/S); - // for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); - //} else { - // std::memset(qkv, 0, 128*sizeof(float)); - //} - //return; } DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0}; @@ -6581,7 +6495,8 @@ void iqk_flash_helper_2(int nq, // number of elements in q mul_mat_fX_fY_1(nq, k, stride_k, info, nk); //mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk); //softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true); - auto sum = softmax_extended_dont_scale(nk, qk, scale, slope, (const char *)mask, true); + auto [sum, ncc] = softmax_extended_dont_scale(nk, qk, scale, slope, (const char *)mask, true); + nk = ncc; GGML_ASSERT(nq%16 
== 0); if (nq/16 <= 16) {