Fix avx2 GEMM mess (v2) (#724)

* This fixes confusion around Q8_0 on AVX2

* This does it for iq4_nl, including FA

* This does it for iq4_nl on Zen4, but FA does not work

* Slightly clearer

* Adding forgotten q8_0_r8 to num_rows()

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-08-27 08:03:47 +03:00 (committed by GitHub)
Parent: ac4ec50f03
Commit: 3dc4dffed5
6 changed files with 114 additions and 60 deletions

===== File 1 of 6 =====

@@ -5562,11 +5562,7 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-#ifdef HAVE_FANCY_SIMD
-    enum ggml_type dot_type = GGML_TYPE_Q8_1_X4;
-#else
-    enum ggml_type dot_type = GGML_TYPE_Q8_0_X4;
-#endif
+    enum ggml_type dot_type = GGML_TYPE_Q8_2_X4;
     if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) {
         return;
     }
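For context: _mm256_maddubs_epi16 multiplies unsigned bytes from its first operand by signed bytes from its second, then adds adjacent pairs with saturating int16 arithmetic. This is why the Zen4 trick of adding 128 to the Q8 quants and subtracting the block sums afterwards gives wrong results on plain AVX2, and why this hunk routes everything through Q8_2_X4. A minimal standalone demonstration (not part of the commit):

#include <immintrin.h>
#include <cstdio>

int main() {
    __m256i x = _mm256_set1_epi8(char(0xFF)); // q = 127 after the +128 offset
    __m256i y = _mm256_set1_epi8(127);        // a maximal Q8 activation quant
    __m256i p = _mm256_maddubs_epi16(x, y);   // per pair: 255*127 + 255*127
    int r = short(_mm256_extract_epi16(p, 0));
    std::printf("got %d, exact sum is %d\n", r, 255 * 127 * 2); // 32767 vs 64770
}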

===== File 2 of 6 =====

@@ -856,10 +856,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
             .from_float_to_mat        = quantize_mat_q8_0,
             .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
 #if GGML_USE_IQK_MULMAT
-#ifdef HAVE_FANCY_SIMD
-        // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2
-        // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range
-        // (and it gets satured if it does), leading to wrong results.
+#ifdef __AVX2__
             .vec_dot_type         = GGML_TYPE_Q8_2_X4,
 #else
             .vec_dot_type         = GGML_TYPE_Q8_0_X4,
@@ -1314,7 +1311,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
             .from_float               = quantize_row_iq4_nl,
             .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_nl_ref,
             .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,
-#if defined HAVE_FANCY_SIMD
+#if __AVX2__
             .vec_dot_type         = GGML_TYPE_Q8_2_X4,
 #else
             .vec_dot_type         = GGML_TYPE_Q8_0_X4,
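The vec_dot_type entry tells ggml which block format the float activations must be quantized to before the dot kernel runs, so these two table hunks are what actually move AVX2 builds onto the Q8_2 path. A toy model of that mechanism (a sketch with made-up stub names, not verbatim ggml code):

#include <cstdio>

enum ggml_type { GGML_TYPE_Q8_0_X4, GGML_TYPE_Q8_2_X4 };

struct type_traits_t {
    ggml_type vec_dot_type;                        // required activation format
    void (*from_float)(const float*, void*, int);  // quantizer for that format
    void (*vec_dot)(int, float*, const void*, const void*);
};

// Hypothetical stubs standing in for the real quantizer and kernel.
void quantize_q8_2_x4(const float*, void*, int) { std::puts("quantize rhs to Q8_2_X4"); }
void dot_q8_0(int, float*, const void*, const void*) { std::puts("run the dot kernel"); }

int main() {
    type_traits_t traits = { GGML_TYPE_Q8_2_X4, quantize_q8_2_x4, dot_q8_0 };
    float src1_row[32] = {};
    unsigned char wdata[64];                 // quantized activations live here
    traits.from_float(src1_row, wdata, 32);  // done once per activation row
    float sum = 0;
    traits.vec_dot(32, &sum, nullptr, wdata);
}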

===== File 3 of 6 =====

@@ -615,13 +615,8 @@ struct HelperIQ4nl final : public BaseHelper {
     constexpr static int block_size_q = QK8_0;
 #else
     HelperIQ4nl(const char * data, int stride) : Base(data, stride) {}
-#ifdef HAVE_FANCY_SIMD
     using block_q8 = block_q8_2;
     constexpr static int block_size_q = QK8_2;
-#else
-    using block_q8 = block_q8_0;
-    constexpr static int block_size_q = QK8_0;
-#endif
 #endif
     // Needed for v * softmax(k * q)
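With the inner #ifdef gone, the flash-attention helper always quantizes to block_q8_2 on x86. The Q8_2 structs themselves never appear in this diff; the layout below is inferred from convert_scales()/prepare1() further down (four bf16 scales followed by four int16 block sums), so the field names and shapes are assumptions, not upstream definitions:

#include <cstdint>
#include <cstdio>
#include <cstring>

constexpr int QK8_2 = 32;
struct block_q8_2    { uint16_t d;    int16_t s; int8_t qs[QK8_2];   }; // bf16 scale, int16 sum (inferred)
struct block_q8_2_x4 { uint16_t d[8];            int8_t qs[4*QK8_2]; }; // d[0..3] scales, d[4..7] sums (inferred)

// Scalar equivalent of the _mm_slli_epi32(..., 16) in convert_scales_s:
// bf16 is the top half of an IEEE f32, so decoding is a 16-bit shift.
float bf16_to_f32(uint16_t h) {
    uint32_t bits = uint32_t(h) << 16;
    float f; std::memcpy(&f, &bits, sizeof f);
    return f;
}

int main() {
    float d = 0.125f; uint32_t bits; std::memcpy(&bits, &d, sizeof bits);
    block_q8_2 b = { uint16_t(bits >> 16), 0, {} };          // truncating f32 -> bf16
    std::printf("decoded scale = %g\n", bf16_to_f32(b.d));   // prints 0.125
}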

===== File 4 of 6 =====

@@ -148,6 +148,16 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf
     funcs[6] = kernel<Dequantizer, 7>;\
     funcs[7] = kernel<Dequantizer, 8>;\
 
+#define IQK_SET_MUL_MAT_FUNCTIONS_T2(kernel, Dequantizer, Block, funcs) \
+    funcs[0] = kernel<Dequantizer, 1, Block>;\
+    funcs[1] = kernel<Dequantizer, 2, Block>;\
+    funcs[2] = kernel<Dequantizer, 3, Block>;\
+    funcs[3] = kernel<Dequantizer, 4, Block>;\
+    funcs[4] = kernel<Dequantizer, 5, Block>;\
+    funcs[5] = kernel<Dequantizer, 6, Block>;\
+    funcs[6] = kernel<Dequantizer, 7, Block>;\
+    funcs[7] = kernel<Dequantizer, 8, Block>;\
+
 #define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
     funcs[0] = kernel<1>;\
     funcs[1] = kernel<2>;\
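The new _T2 macro is the same function-table trick as _T, with the activation block type threaded through as an extra template parameter. A self-contained sketch of the pattern (simplified names):

#include <array>
#include <cstdio>

struct block_q8_0 {}; struct block_q8_2 {};   // stand-ins for the ggml blocks
using mul_mat_t = void (*)(int n);

template <int nrc_y, typename Block>
void kernel(int n) { std::printf("n=%d, ny=%d\n", n, nrc_y); }

#define SET_FUNCS_T2(kern, Block, funcs) \
    funcs[0] = kern<1, Block>; funcs[1] = kern<2, Block>;\
    funcs[2] = kern<3, Block>; funcs[3] = kern<4, Block>;\
    funcs[4] = kern<5, Block>; funcs[5] = kern<6, Block>;\
    funcs[6] = kern<7, Block>; funcs[7] = kern<8, Block>;

int main() {
    std::array<mul_mat_t, 8> funcs;
    SET_FUNCS_T2(kernel, block_q8_2, funcs)   // one instantiation per column count
    funcs[3](256);                            // dispatch: 4 RHS columns
}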

===== File 5 of 6 =====

@@ -79,6 +79,16 @@ template <typename Q8, typename Q8x4> struct Sum4q4 {
     inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); }
 };
 
+inline __m256 convert_scales(const uint16_t * scales) {
+    auto aux_d = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales)), 16));
+    auto aux_m = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(scales+4))));
+    return _mm256_set_m128(_mm_mul_ps(aux_d, aux_m), aux_d);
+}
+
+inline __m128 convert_scales_s(const uint16_t * scales) {
+    return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales)), 16));
+}
+
 struct ScaleHelperQ8_0 {
     inline __m128 prepare4(const block_q8_0 * y) {
         const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
@@ -106,6 +116,20 @@ struct ScaleHelperQ_0 {
     template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
 };
+
+struct ScaleHelperQ8_2S {
+    template <typename Q>
+    inline __m128 prepare4(const Q * y) {
+        const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y;
+        return convert_scales_s((const uint16_t *)y4->d);
+    }
+    template <typename Q>
+    inline __m128 prepare4(__m128 other_scales, const Q * y) {
+        return _mm_mul_ps(other_scales, prepare4<Q>(y));
+    }
+    template <typename Q> static inline float prepare1(const Q * y) { return GGML_BF16_TO_FP32(ggml_bf16_t{y->d}); }
+    template <typename Q> static inline float prepare1(float d, const Q * y) { return d*prepare1(y); }
+};
 struct ScaleHelperQ_0_MXFP4 {
     float scales[4];
     template <typename Q>
@@ -188,12 +212,6 @@ struct ScaleHelperQ8_1 {
     }
 };
-
-inline __m256 convert_scales(const uint16_t * scales) {
-    auto aux_d = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales)), 16));
-    auto aux_m = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(scales+4))));
-    return _mm256_set_m128(_mm_mul_ps(aux_d, aux_m), aux_d);
-}
 struct ScaleHelperQ8_2 {
     template <typename Q>
     inline __m256 prepare4(const Q * y) {
@@ -348,6 +366,7 @@ using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
 using Sum4TypeQ80  = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
 using Sum4TypeQ82  = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false>;
+using Sum4TypeQ82S = Sum4<block_q8_2, block_q8_2_x4, SignedDot, false>;
 
 template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
 void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
@@ -374,10 +393,35 @@ void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo&
     }
 }
 
-template <typename Unpacker, int nrc_y>
+template <typename Unpacker, int nrc_y, typename Block = block_q8_0>
 void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%Unpacker::block_size() == 0);
-    Q8<nrc_y, block_q8_0> q8(info);
+    Q8<nrc_y, Block> q8(info);
     int nb = n/Unpacker::block_size();
+    if constexpr (std::is_same_v<Block, block_q8_2>) {
+        if (nb%4 == 0) {
+            mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_2S, Block, nrc_y>(
+                nb, vx, bx, info, q8.y, nrc_x);
+        } else {
+            mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false>, ScaleHelperQ8_2S, Block, nrc_y>(
+                nb, vx, bx, info, q8.y, nrc_x);
+        }
+    }
+    else {
+        if (nb%4 == 0) {
+            mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_0, Block, nrc_y>(
+                nb, vx, bx, info, q8.y, nrc_x);
+        } else {
+            mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false>, ScaleHelperQ8_0, Block, nrc_y>(
+                nb, vx, bx, info, q8.y, nrc_x);
+        }
+    }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_0_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n%Unpacker::block_size() == 0);
+    Q8<nrc_y, block_q8_2> q8(info);
+    int nb = n/Unpacker::block_size();
     if (nb%4 == 0) {
         mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
@@ -393,11 +437,11 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info
 template <typename Unpacker, int nrc_y, int nrc_x>
 void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
     static_assert(8%nrc_y == 0);
-    Q8<nrc_y, block_q8_0> q8(info);
+    Q8<nrc_y, block_q8_2> q8(info);
     int nb = n/Unpacker::block_size();
     Unpacker unp(vx, bx);
     typename Unpacker::Sum4T sum4;
-    ScaleHelperQ8_0 scales;
+    ScaleHelperQ8_2S scales;
     __m256 result[8];
     auto store = [&info, &result] (int ix0) {
         if constexpr (nrc_y == 1) {
@@ -549,19 +593,15 @@ struct Q4_0_1_Dequantizer {
     }
 };
 
-struct IQ4_NL_Dequantizer {
+struct IQ4_NL_DequantizerU {
     Dequantizer4bit b4;
-#ifdef HAVE_FANCY_SIMD
     const __m256i values = load_iq4nl_values_256();
-#else
-    const __m256i values = load_iq4k_values_256();
-#endif
     inline __m256i dequant(const block_iq4_nl * x) const {
         return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
     }
 };
 
-struct IQ4_NL0_Dequantizer {
+struct IQ4_NL_DequantizerS {
     Dequantizer4bit b4;
     const __m256i values = load_iq4k_values_256();
     inline __m256i dequant(const block_iq4_nl * x) const {
@@ -705,7 +745,7 @@ struct Q_Unpacker {
 struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
     Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
-    using Sum4T = Sum4TypeQ80;
+    using Sum4T = Sum4TypeQ82S;
     inline static int block_size() { return QK8_0; }
 };
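Switching Q8_0_Unpacker from Sum4TypeQ80 to Sum4TypeQ82S is the heart of the fix: the AVX2 kernel keeps maddubs but runs it as a signed dot product. SignedDot's definition is not shown in this diff; the sketch below is the standard AVX2 signed-times-signed idiom it plausibly wraps, where moving the sign of x onto y lets maddubs see |x| <= 128, so every pairwise sum stays inside the int16 range:

#include <immintrin.h>
#include <cstdio>

int main() {
    __m256i x  = _mm256_set1_epi8(-128);        // worst-case signed weight quant
    __m256i y  = _mm256_set1_epi8(127);         // worst-case activation quant
    __m256i ax = _mm256_sign_epi8(x, x);        // |x| as unsigned bytes (0..128)
    __m256i sy = _mm256_sign_epi8(y, x);        // y with the sign of x applied
    __m256i p  = _mm256_maddubs_epi16(ax, sy);  // 128*(-127) + 128*(-127)
    std::printf("%d\n", int(short(_mm256_extract_epi16(p, 0)))); // -32512, no overflow
}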
struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {
@@ -713,6 +753,11 @@ struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<12
     using Sum4T = Sum4TypeQ82;
     inline static int block_size() { return QK8_0; }
 };
+struct Q8_0_2_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
+    Q8_0_2_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+    using Sum4T = Sum4TypeQ82;
+    inline static int block_size() { return QK8_0; }
+};
 struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
     Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
     using Sum4T = Sum4TypeQ80;
@@ -729,19 +774,16 @@ struct MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MX
     using Sum4T = Sum4TypeQ82;
     inline static int block_size() { return QK4_NL; }
 };
-#ifdef HAVE_FANCY_SIMD
-struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_Dequantizer> {
-    IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+struct IQ4_NL_UnpackerU final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_DequantizerU> {
+    IQ4_NL_UnpackerU(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
     using Sum4T = Sum4TypeQ82;
     inline static int block_size() { return QK4_NL; }
 };
-#else
-struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL0_Dequantizer> {
-    IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
-    using Sum4T = Sum4TypeQ80;
+struct IQ4_NL_UnpackerS final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_DequantizerS> {
+    IQ4_NL_UnpackerS(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+    using Sum4T = Sum4TypeQ82S;
     inline static int block_size() { return QK4_NL; }
 };
-#endif
 struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
     Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
     using Sum4T = Sum4TypeQ80;
@@ -1872,19 +1914,20 @@ void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int
 }
 
 template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
-    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
-                  std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
+    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {
         IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
     }
+    else if constexpr (std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
+        IQK_SET_MUL_MAT_FUNCTIONS_T2(mul_mat_qX_0_q8_0_T, Dequantizer, block_q8_2, funcs)
+    }
     else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
         IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
     }
-    else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
-#ifdef HAVE_FANCY_SIMD
+    else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_UnpackerU>) {
         IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
-#else
-        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
-#endif
     }
+    else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_UnpackerS>) {
+        IQK_SET_MUL_MAT_FUNCTIONS_T2(mul_mat_qX_0_q8_0_T, Dequantizer, block_q8_2, funcs)
+    }
     else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
                        std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker> ||
@@ -1902,7 +1945,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
         case GGML_TYPE_Q5_0  : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q5_1  : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q6_0  : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-        case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL_DequantizerS>(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q8_0  : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8<block_mxfp4, MXFP40_Dequantizer>(n, vx, bx, vy, nrc_x); break;
         default: return false;
@@ -1939,20 +1982,17 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
             set_functions<Q8_0_1_Unpacker>(kernels);
 #else
             set_functions<Q8_0_Unpacker>(kernels);
-            expected_typeB = GGML_TYPE_Q8_0_X4;
 #endif
             break;
         case GGML_TYPE_IQ4_NL:
-            set_functions<IQ4_NL_Unpacker>(kernels);
-#ifndef HAVE_FANCY_SIMD
-            expected_typeB = GGML_TYPE_Q8_0_X4;
+#ifdef HAVE_FANCY_SIMD
+            set_functions<IQ4_NL_UnpackerU>(kernels);
+#else
+            set_functions<IQ4_NL_UnpackerS>(kernels);
 #endif
             break;
         case GGML_TYPE_MXFP4:
             set_functions<MXFP4_Unpacker>(kernels);
-//#ifndef HAVE_FANCY_SIMD
-//            expected_typeB = GGML_TYPE_Q8_0_X4;
-//#endif
             break;
         case GGML_TYPE_Q4_0_R8:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
@@ -3223,6 +3263,19 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
         case 7: return std::make_pair(mul_mat, 7>, 7);\
     }\
 }
+#define MAKE_FUNCS2(mul_mat, block, n) \
+    if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ, block>, kMaxQ);\
+    else {\
+        switch (n) {\
+            case 1: return std::make_pair(mul_mat, 1, block>, 1);\
+            case 2: return std::make_pair(mul_mat, 2, block>, 2);\
+            case 3: return std::make_pair(mul_mat, 3, block>, 3);\
+            case 4: return std::make_pair(mul_mat, 4, block>, 4);\
+            case 5: return std::make_pair(mul_mat, 5, block>, 5);\
+            case 6: return std::make_pair(mul_mat, 6, block>, 6);\
+            case 7: return std::make_pair(mul_mat, 7, block>, 7);\
+        }\
+    }
 #define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
     if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
     else {\
@@ -3249,7 +3302,11 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
             if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
             if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2, k_step>, 2);
             if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4, k_step>, 4);
-            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
+            if (nq == 3) return std::make_pair(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3, block_q8_2>, 3);
+            if (nq == 5) return std::make_pair(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5, block_q8_2>, 5);
+            if (nq == 6) return std::make_pair(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6, block_q8_2>, 6);
+            if (nq == 7) return std::make_pair(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7, block_q8_2>, 7);
+            return std::make_pair(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, kMaxQ, block_q8_2>, kMaxQ);
 #endif
 #endif
 }
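MAKE_FUNCS and MAKE_FUNCS2 map the runtime row count nq onto a template instantiation, which is exactly what the explicit if chain above does by hand. Expanded, the pattern behaves like this standalone sketch (simplified names, not upstream code):

#include <cstdio>
#include <utility>

constexpr int kMaxQ = 8;
struct block_q8_2 {};                          // stand-in for the ggml block

template <int nrc_y, typename Block>
void mul_mat(int n) { std::printf("ny=%d, n=%d\n", nrc_y, n); }

std::pair<void (*)(int), int> pick(int nq) {
    if (nq >= kMaxQ) return { mul_mat<kMaxQ, block_q8_2>, kMaxQ };
    switch (nq) {
        case 1:  return { mul_mat<1, block_q8_2>, 1 };
        case 2:  return { mul_mat<2, block_q8_2>, 2 };
        case 3:  return { mul_mat<3, block_q8_2>, 3 };
        case 4:  return { mul_mat<4, block_q8_2>, 4 };
        case 5:  return { mul_mat<5, block_q8_2>, 5 };
        case 6:  return { mul_mat<6, block_q8_2>, 6 };
        default: return { mul_mat<7, block_q8_2>, 7 };
    }
}

int main() {
    auto [fn, ny] = pick(5);
    std::printf("kernel covers %d rows\n", ny);
    fn(256);                                   // calls the 5-row instantiation
}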
@@ -3293,9 +3350,9 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
             MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
 #else
 #ifdef HAVE_FANCY_SIMD
-            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker, nq);
+            MAKE_FUNCS(mul_mat_qX_1_q8_2_T<IQ4_NL_UnpackerU, nq);
 #else
-            MAKE_FUNCS(mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker, nq);
+            MAKE_FUNCS2(mul_mat_qX_0_q8_0_T<IQ4_NL_UnpackerS, block_q8_2, nq);
 #endif
 #endif
 }
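Both IQ4_NL dequantizers in this file rely on _mm256_shuffle_epi8 as a 16-entry lookup table: each 4-bit quant indexes one of 16 codebook bytes per 128-bit lane. A standalone sketch of the trick (table values as in the upstream IQ4_NL codebook, kvalues_iq4nl; the real kernels load them via load_iq4nl_values_256()/load_iq4k_values_256()):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    alignas(16) const int8_t lut[16] = { -127, -104, -83, -65, -49, -35, -22, -10,
                                            1,   13,  25,  38,  53,  69,  89, 113 };
    __m128i t128 = _mm_load_si128((const __m128i*)lut);
    __m256i tbl  = _mm256_broadcastsi128_si256(t128); // same codebook in both lanes
    __m256i idx  = _mm256_set1_epi8(3);               // pretend 4-bit quant indices
    __m256i vals = _mm256_shuffle_epi8(tbl, idx);     // per-byte table lookup
    std::printf("%d\n", int(int8_t(_mm256_extract_epi8(vals, 0)))); // prints -65
}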

===== File 6 of 6 =====

@@ -266,9 +266,7 @@ struct MulMat {
         case GGML_TYPE_Q5_0  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q5_1  : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
         case GGML_TYPE_Q6_0  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-#ifdef HAVE_FANCY_SIMD
         case GGML_TYPE_IQ4_NL: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-#endif
         case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_Q8_0  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_KT: return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
@@ -351,6 +349,7 @@ struct MulMat {
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q8_K_R8: return 8;
         case GGML_TYPE_Q4_0_R8:
+        case GGML_TYPE_Q8_0_R8:
         case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_BF16_R16: return 16;
         default: return 1;