This commit is contained in:
Kawrakow
2024-09-13 15:46:36 +03:00
parent 2bafb03aac
commit e23dce7a51

View File

@@ -2895,20 +2895,6 @@ template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct
}
}
};
// NOTE: the alternative Sum4 implementation below is left disabled because using it negatively impacts q4_1/q5_1 performance.
//template <typename Q8, typename Q8x4, typename Dot> struct Sum4 {
// Dot dot;
// inline __m256i compute(const __m256i * qx, const Q8 * y) const {
// const Q8x4 * y4 = (const Q8x4 *)y;
// const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0
// const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1
// const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2
// const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3
// auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1
// auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3
// return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
// }
//};
struct ScaleHelperQ8_0 {
inline __m128 prepare4(const block_q8_0 * y) {
@@ -6849,7 +6835,6 @@ void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
static_assert(step == QK8_0);
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
HelperQ80(const char * data, int stride) : Base(data, stride) {}
@@ -6898,7 +6883,6 @@ struct HelperQ80 final : public BaseHelper<step> {
template <int D, int step>
struct HelperQ40 final : public BaseHelper<step> {
static_assert(step == QK4_0);
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
HelperQ40(const char * data, int stride) : Base(data, stride) {}
@@ -6942,7 +6926,6 @@ struct HelperQ40 final : public BaseHelper<step> {
template <int D, int step>
struct HelperQ41 final : public BaseHelper<step> {
static_assert(step == QK4_1);
using Base = BaseHelper<step>;
using block_q8 = block_q8_1;
HelperQ41(const char * data, int stride) : Base(data, stride) {}
@@ -7268,7 +7251,7 @@ struct FlashQKV {
F16::Data v1, v2;
for (int l1 = 0; l1 < k_step; ++l1) {
vh.load(l1, i, v1, v2);
for (int j = 0; j < q_step; ++j) {
for (int j = 0; j < nq1; ++j) {
auto vs = F16::set1(fms.cache[k_step*j + l1]);
vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs);
vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs);