diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 22deb56a..19ca8684 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6371,6 +6371,35 @@ void iqk_flash_helper(int nq, // number of elements in q softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true); } +namespace { +inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) { + if (smax > M) { + if (M > -INFINITY) { + float ms = expf(M - smax); + auto vms = _mm512_set1_ps(ms); + for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); + S *= ms; + } else { + for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); + S = 0; + } + M = smax; + } + //auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M))); + //_mm256_storeu_ps(saux, vs_all); + auto vs_all = v_expf(_mm512_sub_ps(_mm512_loadu_ps(saux), _mm512_set1_ps(M))); + _mm512_storeu_ps(saux, vs_all); + for (int j = 0; j < n; ++j) { + S += saux[j]; + auto vs = _mm512_set1_ps(saux[j]); + auto vr = (const ggml_half *)(v + stride_v*j); + for (int i = 0; i < 8; ++i) { + acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + } + } +} +} + void iqk_flash_helper_2(int nq, // number of elements in q int nk, // number of rows in k int stride_k, // distance between rows in k (in bytes) @@ -6387,22 +6416,91 @@ void iqk_flash_helper_2(int nq, // number of elements in q //GGML_ASSERT(nq / 16 <= 16); if (nq == 128) { + constexpr int kNchunk = 16; const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr; __m512 vq[8]; - __m512 acc[8] = {}; + __m512 acc[8]; // = {}; for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); float M = -INFINITY; float S = 0; + float saux[kNchunk]; + float smax = -INFINITY; + int last_ik = 0; int ik = 0; - //if (nk >= 8) { + for (; ik < nk; ++ik) { + if (ik - last_ik == kNchunk) { + accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + //if (smax > M) { + // if (M > -INFINITY) { + // float ms = expf(M - smax); + // auto vms = _mm512_set1_ps(ms); + // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); + // S *= ms; + // } + // M = smax; + //} + //auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M))); + //_mm256_storeu_ps(saux, vs_all); + //for (int j = 0; j < 8; ++j) { + // S += saux[j]; + // auto vs = _mm512_set1_ps(saux[j]); + // auto vr = (const ggml_half *)((const char *)v + stride_v*(last_ik + j)); + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + // } + //} + last_ik = ik; + smax = -INFINITY; + } + const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; + if (mv == -INFINITY) break; //continue; + const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); + auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); + for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); + float s = scale * _mm512_reduce_add_ps(sum) + mv; + saux[ik - last_ik] = s; + smax = std::max(smax, s); + } + if (ik > last_ik) { + accumulate(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v); + //if (smax > M) { + // if (M > -INFINITY) { + // float ms = expf(M - smax); + // auto vms = _mm512_set1_ps(ms); + // for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]); + // S *= ms; + // } + // M = smax; + //} + //for (int j = last_ik; j < ik; ++j) { + // float s = expf(saux[j - last_ik] - M); + // S += s; + // auto vs = _mm512_set1_ps(s); + // auto vr = (const ggml_half *)((const char *)v + stride_v*j); + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + // } + //} + } + if (S > 0) { + auto norm = _mm512_set1_ps(1/S); + for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); + } else { + printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M); + GGML_ASSERT(false); + std::memset(qkv, 0, 128*sizeof(float)); + } + return; + //int ik = 0; + //if (false && nk >= 8) { + // bool finished = false; // float s8[8]; // for (int ik8 = 0; ik8 < nk/8; ++ik8) { // float smax = -INFINITY; - // for (int j = 0; j < 8; ++j) { + // int j = 0; + // for (; j < 8; ++j) { // const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f; - // if (mv == -INFINITY) { - // s8[j] = -INFINITY; continue; - // } + // if (mv == -INFINITY) break; // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j)); // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); @@ -6410,6 +6508,7 @@ void iqk_flash_helper_2(int nq, // number of elements in q // s8[j] = s; // smax = std::max(smax, s); // } + // if (smax == -INFINITY) { finished = true; break; } // if (smax > M) { // if (M > -INFINITY) { // float ms = expf(M - smax); @@ -6418,75 +6517,63 @@ void iqk_flash_helper_2(int nq, // number of elements in q // S *= ms; // } else { // for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); + // S = 0; // } // M = smax; // } - // if (smax == -INFINITY) break; // _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M)))); - // for (int j = 0; j < 8; ++j) { - // if (s8[j] <= 0.0f) continue; - // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j)); - // auto vs = _mm512_set1_ps(s8[j]); - // S += s8[j]; + // for (int l = 0; l < j; ++l) { + // if (s8[l] <= 0.0f) continue; + // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+l)); + // auto vs = _mm512_set1_ps(s8[l]); + // S += s8[l]; // for (int i = 0; i < 8; ++i) { // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); // } // } // } - // ik = 8*(nk/8); + // ik = finished ? nk : 8*(nk/8); //} - //int last_i = -1; - for (; ik < nk; ++ik) { - const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; - if (mv == -INFINITY) break; //continue; - //last_i = ik; - const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); - const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik); - //auto sum1 = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+0))); - //auto sum2 = _mm512_mul_ps(vq[1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+1))); - //for (int i = 2; i < 8; i += 2) { - // sum1 = _mm512_fmadd_ps(vq[i+0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0)), sum1); - // sum2 = _mm512_fmadd_ps(vq[i+1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1)), sum2); - //} - //float s = scale * _mm512_reduce_add_ps(_mm512_add_ps(sum1, sum2)) + mv; - auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); - for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); - float s = scale * _mm512_reduce_add_ps(sum) + mv; - if (s > M) { - if (M > -INFINITY) { - //if (M - s > -20.f) { - float ms = expf(M - s); - auto vms = _mm512_set1_ps(ms); - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); - } - S = ms*S + 1; - } else { - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); - } - S = 1; - } - M = s; - } else { - //if (s - M > -20.f) { - float vs = expf(s - M); - auto vvs = _mm512_set1_ps(vs); - for (int i = 0; i < 8; ++i) { - acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); - } - S += vs; - //} - } - } - //if (last_i < 0) { + //int last_ik = 0; + //for (; ik < nk; ++ik) { + // const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; + // if (mv == -INFINITY) break; //continue; + // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); + // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik); + // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); + // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); + // float s = scale * _mm512_reduce_add_ps(sum) + mv; + // if (s > M) { + // if (M > -INFINITY) { + // float ms = expf(M - s); + // auto vms = _mm512_set1_ps(ms); + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); + // } + // S = ms*S + 1; + // } else { + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); + // } + // S = 1; + // } + // M = s; + // } else { + // float vs = expf(s - M); + // auto vvs = _mm512_set1_ps(vs); + // for (int i = 0; i < 8; ++i) { + // acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); + // } + // S += vs; + // } + //} + //if (S > 0) { + // auto norm = _mm512_set1_ps(1/S); + // for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); + //} else { // std::memset(qkv, 0, 128*sizeof(float)); - // return; //} - //printf("%s: nk = %d, last_i = %d\n", __func__, nk, last_i); - auto norm = _mm512_set1_ps(1/S); - for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i])); - return; + //return; } DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};