mxfp4: repacked GEMM (AVX2/Zen4)

This commit is contained in:
Iwan Kawrakow
2025-08-08 10:58:11 +03:00
parent 294341a3d2
commit a5e87adfa7
2 changed files with 13 additions and 5 deletions

View File

@@ -580,7 +580,7 @@ static inline __m256i load_mxfp4_values_256() {
return MM256_SET_M128I(val128, val128);
}
struct IQ4_MXFP4_Dequantizer {
struct MXFP4_Dequantizer {
Dequantizer4bit b4;
#ifdef HAVE_FANCY_SIMD
const __m256i values = load_unsigned_mxfp4_values_256();
@@ -592,7 +592,7 @@ struct IQ4_MXFP4_Dequantizer {
}
};
struct IQ4_MXFP40_Dequantizer {
struct MXFP40_Dequantizer {
Dequantizer4bit b4;
const __m256i values = load_mxfp4_values_256();
inline __m256i dequant(const block_mxfp4 * x) const {
@@ -718,13 +718,13 @@ struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_NL; }
};
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MXFP4<12>, IQ4_MXFP4_Dequantizer> {
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MXFP4<12>, MXFP4_Dequantizer> {
IQ4_MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_NL; }
};
#else
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0, IQ4_MXFP4_Dequantizer> {
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0, MXFP4_Dequantizer> {
IQ4_MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK4_NL; }
@@ -1809,7 +1809,11 @@ void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
y[i].d[k] = x8[k][i].d;
if constexpr (std::is_same_v<Dequantizer, MXFP40_Dequantizer>) {
y[i].d[k] = GGML_FP32_TO_FP16(GGML_E8M0_TO_FP32_HALF(x8[k][i].e));
} else {
y[i].d[k] = x8[k][i].d;
}
_mm256_storeu_si256((__m256i *)block, deq.dequant(x8[k] + i));
auto qs = (uint32_t *)y[i].qs;
for (int l = 0; l < 4; ++l) {
@@ -1887,6 +1891,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8<block_mxfp4, MXFP40_Dequantizer>(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;

View File

@@ -266,6 +266,7 @@ struct MulMat {
case GGML_TYPE_Q5_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
@@ -295,6 +296,7 @@ struct MulMat {
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_MXFP4 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
@@ -458,6 +460,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
//case GGML_TYPE_Q4_0_R8:
//case GGML_TYPE_Q5_0_R4:
//case GGML_TYPE_Q6_0_R4: