diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a009f1a5..a03b6429 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -8523,7 +8523,22 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf typename Unpacker::Sum4T sum4; ScaleHelperQ8_0 scales; __m256 result[8]; - float val[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; if (nb%4 == 0) { for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { for (int ix = 0; ix < 8/nrc_y; ++ix) { @@ -8531,8 +8546,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf AccumType0 accum; accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); } - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + store(ix0); } } else { for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { @@ -8541,8 +8555,7 @@ void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& inf AccumType0 accum; accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); } - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + store(ix0); } } } @@ -8589,7 +8602,22 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf typename Unpacker::Sum4T sum4; ScaleHelperQ8_2 scales; __m256 result[8]; - float val[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; if (nb%4 == 0) { for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { for (int ix = 0; ix < 8/nrc_y; ++ix) { @@ -8597,8 +8625,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf AccumType1 accum; accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); } - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + store(ix0); } } else { for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { @@ -8607,8 +8634,7 @@ void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& inf AccumType1 accum; accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); } - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + store(ix0); } } } @@ -16337,7 +16363,6 @@ struct FlashQKV { accumulate_qkv_1(vh, fms); return; } - F16::Data v[8]; for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; if (fms.need_scaling[j] == 2) { @@ -16350,6 +16375,40 @@ struct FlashQKV { } } } +#ifdef __AVX512F__ + if constexpr ((D/F16::block_size)%4 == 0) { + F16::Data v[16]; + F16::Data vs[4]; + for (int i = 0; i < D/F16::block_size; i += 4) { + for (int l = 0; l < k_step; l += 4) { + for (int k = 0; k < 4; ++k) { + vh.load(l+k, i+0, v[4*k+0], v[4*k+1]); + vh.load(l+k, i+2, v[4*k+2], v[4*k+3]); + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto s3 = F16::load(R + F16::block_size*(i+2)); + auto s4 = F16::load(R + F16::block_size*(i+3)); + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[4*k+0], vs[k]); + s2 = F16::fmadd(s2, v[4*k+1], vs[k]); + s3 = F16::fmadd(s3, v[4*k+2], vs[k]); + s4 = F16::fmadd(s4, v[4*k+3], vs[k]); + } + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); + F16::store(R + F16::block_size*(i+2), s3); + F16::store(R + F16::block_size*(i+3), s4); + } + } + } + return; + } +#endif + F16::Data v[8]; #ifdef __AVX2__ F16::Data vs[4]; #endif @@ -17035,7 +17094,7 @@ struct FlashAttn { std::is_same_v> || std::is_same_v> || std::is_same_v>) { - constexpr size_t kMaxOnStackSize = 576; + constexpr size_t kMaxOnStackSize = 18432; //576; auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); q_size = GGML_PAD(q_size, 64); if (q_size > kMaxOnStackSize) {