mxfp4: Zen4 GEMM

Iwan Kawrakow
2025-08-08 09:23:02 +03:00
parent 58c3bffff4
commit 294341a3d2
5 changed files with 90 additions and 25 deletions

View File

@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
{ "Q6_0", LLAMA_FTYPE_MOSTLY_Q6_0, " 6.5 bpw quantization", },
{ "MXFP4", LLAMA_FTYPE_MOSTLY_MXFP4, " 4.25 bpw 4-bit float quantization",},
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
{ "IQ2_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4,"IQ2_XXS repacked", },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },

View File

@@ -128,28 +128,28 @@ struct ScaleHelperQ_0_1 {
const __m128 min = _mm_set1_ps(float(-min_value));
};
//template <int min_value>
//struct ScaleHelperQ_0_2 {
// ggml_bf16_t scales8[4];
// template <typename Q>
// inline __m256 prepare4(const Q * y) {
// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16));
// return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
// }
// template <typename Q>
// inline __m256 prepare4(__m256 other_scales, const Q * y) {
// return _mm256_mul_ps(other_scales, prepare4<Q>(y));
// }
// template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
// float d = GGML_BF16_TO_FP32(y->d);
// return std::make_pair(d, -d*float(min_value));
// }
// std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
// }
// const __m128 min = _mm_set1_ps(float(-min_value));
//};
template <int min_value>
struct ScaleHelperQ_0_1_MXFP4 {
    float scales[4];
    template <typename Q>
    inline __m256 prepare4(const Q * y) {
        for (int j = 0; j < 4; ++j) scales[j] = GGML_E8M0_TO_FP32_HALF(y[j].e);
        auto s4 = _mm_loadu_ps(scales);
        return _mm256_set_m128(_mm_mul_ps(s4, min), s4);
    }
    template <typename Q>
    inline __m256 prepare4(__m256 other_scales, const Q * y) {
        return _mm256_mul_ps(other_scales, prepare4<Q>(y));
    }
    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
        float d = GGML_E8M0_TO_FP32_HALF(y->e);
        return std::make_pair(d, -d*float(min_value));
    }
    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
    }
    const __m128 min = _mm_set1_ps(float(-min_value));
};
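Editorial note on the scale handling above: the E8M0 byte carries only a power-of-two exponent, and GGML_E8M0_TO_FP32_HALF is assumed to return half of 2^(e-127), compensating for the kvalues_mxfp4 table storing the E2M1 values doubled. The min_value parameter (12 in the Zen4 path) exists because the fast path multiplies against an unsigned copy of the value table shifted up by 12 (see the loaders in the next hunk); prepare1 therefore hands back the pair (d, -d*12) so the shift can be subtracted out again via the Q8 block's precomputed sum. A scalar sketch of that bookkeeping under those assumptions (the helper name is invented for illustration):

#include <cstdint>
#include <utility>
// dot(block) = d * sum_i q_i * y_i                  (q_i: signed kvalues_mxfp4 entries, y_i: activations)
//            = d * sum_i (q_i + 12) * y_i  -  12 * d * sum_i y_i
// The first term uses the unsigned table; the second is folded in as dm.second * y->s.
static inline std::pair<float, float> mxfp4_scale_and_bias(uint8_t e, int min_value = 12) {
    const float d = GGML_E8M0_TO_FP32_HALF(e);   // roughly 0.5f * exp2f(int(e) - 127), per ggml's helper
    return std::make_pair(d, -d * float(min_value));
}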
struct ScaleHelperQ8_1 {
template <typename Q>
@@ -553,6 +553,53 @@ struct IQ4_NL0_Dequantizer {
}
};
//=============================
static inline __m128i load_unsigned_mxfp4_values_128() {
    static const uint8_t kvalues_mxfp4_unsigned[16] = {12, 13, 14, 15, 16, 18, 20, 24, 12, 11, 10, 9, 8, 6, 4, 0};
    return _mm_loadu_si128((const __m128i *)kvalues_mxfp4_unsigned);
}
static inline __m256i load_unsigned_mxfp4_values_256() {
    auto val128 = load_unsigned_mxfp4_values_128();
    return MM256_SET_M128I(val128, val128);
}
#ifdef HAVE_FANCY_SIMD
static inline __m512i load_unsigned_mxfp4_values_512() {
    auto val256 = load_unsigned_mxfp4_values_256();
    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
}
#endif
static inline __m128i load_mxfp4_values_128() {
    return _mm_loadu_si128((const __m128i *)kvalues_mxfp4);
}
static inline __m256i load_mxfp4_values_256() {
    auto val128 = load_mxfp4_values_128();
    return MM256_SET_M128I(val128, val128);
}
struct IQ4_MXFP4_Dequantizer {
    Dequantizer4bit b4;
#ifdef HAVE_FANCY_SIMD
    const __m256i values = load_unsigned_mxfp4_values_256();
#else
    const __m256i values = load_mxfp4_values_256();
#endif
    inline __m256i dequant(const block_mxfp4 * x) const {
        return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
    }
};
struct IQ4_MXFP40_Dequantizer {
    Dequantizer4bit b4;
    const __m256i values = load_mxfp4_values_256();
    inline __m256i dequant(const block_mxfp4 * x) const {
        return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
    }
};
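Editorial note on the two dequantizers above: Dequantizer4bit splits the 16-byte qs array into its 32 nibbles, and _mm256_shuffle_epi8 then maps every nibble through the 16-entry table loaded above, so dequant() returns the 32 codebook values of one block, shifted up by 12 in the HAVE_FANCY_SIMD variant and signed otherwise. A scalar reference of what one dequantized block looks like, assuming the usual kvalues_mxfp4 table and low-nibbles-first ordering (sketch only, not part of the commit):

// Scalar sketch; QK_MXFP4 = 32 and kvalues_mxfp4 as assumed in the note after the first hunk.
static void dequantize_block_mxfp4_ref(const block_mxfp4 * x, float * y) {
    const float d = GGML_E8M0_TO_FP32_HALF(x->e);
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        y[j]              = d * kvalues_mxfp4[x->qs[j] & 0xf];  // low nibbles -> first 16 values
        y[j + QK_MXFP4/2] = d * kvalues_mxfp4[x->qs[j] >>  4];  // high nibbles -> last 16 values
    }
}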
struct Q4_1_Dequantizer {
Dequantizer4bit b4;
inline __m256i dequant(const block_q4_1 * x) const {
@@ -671,9 +718,14 @@ struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<
    using Sum4T = Sum4TypeQ82;
    inline static int block_size() { return QK4_NL; }
};
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MXFP4<12>, IQ4_MXFP4_Dequantizer> {
    IQ4_MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
    using Sum4T = Sum4TypeQ82;
    inline static int block_size() { return QK4_NL; }
};
#else
struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
    IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
struct IQ4_MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0, IQ4_MXFP4_Dequantizer> {
    IQ4_MXFP4_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
    using Sum4T = Sum4TypeQ80;
    inline static int block_size() { return QK4_NL; }
};
@@ -1811,7 +1863,7 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
}
else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker> || std::is_same_v<Dequantizer, IQ4_MXFP4_Unpacker>) {
#ifdef HAVE_FANCY_SIMD
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
#else
@@ -1876,6 +1928,12 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
set_functions<IQ4_NL_Unpacker>(kernels);
#ifndef HAVE_FANCY_SIMD
expected_typeB = GGML_TYPE_Q8_0_X4;
#endif
break;
case GGML_TYPE_MXFP4:
set_functions<IQ4_MXFP4_Unpacker>(kernels);
#ifndef HAVE_FANCY_SIMD
expected_typeB = GGML_TYPE_Q8_0_X4;
#endif
break;
case GGML_TYPE_Q4_0_R8:

View File

@@ -871,6 +871,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_IQ4_NL_R4:
case GGML_TYPE_MXFP4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
@@ -960,6 +961,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_IQ4_NL_R4:
case GGML_TYPE_MXFP4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:

View File

@@ -186,6 +186,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
LLAMA_FTYPE_MOSTLY_MXFP4 = 38, // except 1d tensors, 38 to be compatible with mainline
//
LLAMA_FTYPE_MOSTLY_Q6_0 = 135, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_BN = 136, // except 1d tensors

View File

@@ -4538,6 +4538,7 @@ struct llama_model_loader {
case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break;
case GGML_TYPE_MXFP4: ftype = LLAMA_FTYPE_MOSTLY_MXFP4; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break;
case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break;
@@ -5294,6 +5295,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: return "Q5_0_R4 - 5.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: return "Q6_0_R4 - 6.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q8_0_R8: return "Q8_0_R8 - 8.5 bpw";
case LLAMA_FTYPE_MOSTLY_MXFP4: return "MXFP4 - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw";
@@ -20541,6 +20543,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break;
case LLAMA_FTYPE_MOSTLY_Q8_0_R8: default_type = GGML_TYPE_Q8_0_R8; break;
case LLAMA_FTYPE_MOSTLY_MXFP4: default_type = GGML_TYPE_MXFP4; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break;
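With the QUANT_OPTIONS entry, the ftype value, and the default-type mapping above in place, an MXFP4 model should be producible through the usual quantize tool; a hypothetical invocation (binary name and file paths are placeholders, not taken from this commit):

./bin/llama-quantize ./model-f16.gguf ./model-mxfp4.gguf MXFP4

The resulting model then reports its ftype as "MXFP4 - 4.25 bpw", and matrix multiplications on Zen4 dispatch to the new IQ4_MXFP4_Unpacker path via iqk_set_kernels_legacy_quants.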