diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 67d9828c..9372c05c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3498,6 +3498,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ3_KS: diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 8c03ae1b..689613f5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -736,6 +736,27 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 }; + const int64_t i = blockIdx.x; + const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK4_NL); + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + union { float f; uint32_t u; } helper; + helper.u = x[ib].e >= 2 ? uint32_t(x[ib].e - 1) << 23u : uval[x[ib].e]; + const float d = helper.f; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]; + y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]; + } +} + template static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -1611,6 +1632,13 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_nl<<>>(vx, y); } +template +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_mxfp4<<>>(vx, y); +} + template static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1943,6 +1971,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: @@ -2044,6 +2074,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_IQ4_XS: return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ4_KS: