Use Sum4q4 for q4_0

This commit is contained in:
Iwan Kawrakow
2025-04-23 15:43:59 +03:00
parent b19fd13141
commit cd44692bc0

View File

@@ -8209,6 +8209,22 @@ template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct
return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
}
}
inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); }
};
template <typename Q8, typename Q8x4> struct Sum4q4 {
inline __m256i compute(const __m256i * qx, const Q8 * y) const {
const Q8x4 * y4 = (const Q8x4 *)y;
auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0
auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1
auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2
auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3
auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1
auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3
auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123);
}
inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
};
struct ScaleHelperQ8_0 {
@@ -8413,7 +8429,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
for (int iy = 0; iy < nrc_y; ++iy) {
auto s12 = scales.prepare1(other_scales, y[iy] + i);
auto d = accm.compute(s12, iy);
const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
}
}
@@ -8443,7 +8459,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
for (int iy = 0; iy < nrc_y; ++iy) {
auto s12 = scales.prepare1(other_scales, y[iy] + i);
auto d = accm.compute(s12, iy);
const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
}
}
@@ -8788,7 +8804,8 @@ struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_
};
struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>, Q4_0_1_Dequantizer> {
Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ82;
//using Sum4T = Sum4TypeQ82;
using Sum4T = Sum4q4<block_q8_2, block_q8_2_x4>;
inline static int block_size() { return QK4_0; }
};
#ifdef HAVE_FANCY_SIMD