mxfp4: CUDA dequantize

This commit is contained in:
Iwan Kawrakow
2025-08-08 19:20:39 +03:00
parent fd8384e3aa
commit 3466dbda40
2 changed files with 33 additions and 0 deletions

View File

@@ -3498,6 +3498,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:

View File

@@ -736,6 +736,27 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
}
}
// Dequantization kernel for MXFP4 data.
//
// Launch layout (as used by the host launcher in this file): 1-D grid, 32
// threads per block. CUDA block `blockIdx.x` consumes QK_K/QK4_NL consecutive
// block_mxfp4 records from `vx` and produces QK_K consecutive values in `yy`.
//
// Work split inside a CUDA block: thread t handles quant block (t % 8) and the
// 4-byte slice (t / 8) of that block's `qs` array. Every byte yields two
// outputs: the low nibble goes to position j, the high nibble to position
// j + 16 of the 32-value output group.
//
// The per-block scale is rebuilt directly as an IEEE-754 single bit pattern
// from the exponent byte `e`; both branches evaluate to 2^(e - 128):
//   e >= 2 : (e - 1) placed in the exponent field, zero mantissa (normal)
//   e <  2 : 0x00200000 << e, i.e. the subnormals 2^-128 and 2^-127
//
// NOTE(review): there is no bounds check and no element count parameter, so
// each launched block always reads/writes a full QK_K values - confirm that
// callers only pass sizes that are a multiple of QK_K (or padded buffers).
// NOTE(review): QK4_NL is used as the per-record element count for
// block_mxfp4; presumably both formats hold 32 values per record - confirm.
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t block_id = blockIdx.x;
    const int64_t lane     = threadIdx.x;
    const int64_t slice    = lane/8; // 0...3
    const int64_t sub      = lane%8; // 0...7

    // Quant record handled by this thread and the matching output position.
    const block_mxfp4 & blk = ((const block_mxfp4 *) vx)[block_id*(QK_K/QK4_NL) + sub];
    const uint8_t * packed  = blk.qs + 4*slice;
    dst_t * out             = yy + block_id*QK_K + 32*sub + 4*slice;

    // Load the exponent once and turn it into the float scale bit pattern.
    const auto e = blk.e;
    union { float f; uint32_t u; } scale;
    if (e >= 2) {
        scale.u = uint32_t(e - 1) << 23u;
    } else {
        scale.u = 0x00200000u << e;
    }
    const float d = scale.f;

    for (int j = 0; j < 4; ++j) {
        const uint8_t byte = packed[j];
        out[j+ 0] = d * kvalues_mxfp4[byte & 0xf];
        out[j+16] = d * kvalues_mxfp4[byte >> 4];
    }
}
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
@@ -1611,6 +1632,13 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
// Host-side launcher that dequantizes `nrows * n_per_row` MXFP4 values from
// `vx` into `y` on `stream`.
//
// Parameters:
//   vx        - device pointer to the quantized input
//   y         - device pointer to the output buffer (element type dst_t)
//   nrows     - number of rows
//   n_per_row - number of values per row
//   stream    - CUDA stream the kernel is enqueued on
//
// The launch is asynchronous; one CUDA block of 32 threads is started for
// every QK_K output values (rounded up).
//
// NOTE(review): dequantize_block_mxfp4 takes no element count and always
// processes a full QK_K values per block, so when the total is not a multiple
// of QK_K the rounded-up last block runs past the end - confirm that callers
// guarantee divisibility or that the buffers are padded.
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
    const int64_t k = nrows * n_per_row;
    if (k <= 0) {
        // Nothing to convert. Launching with a zero-sized grid is an invalid
        // launch configuration, and a negative count would wrap to a huge grid.
        return;
    }
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1943,6 +1971,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS:
@@ -2044,6 +2074,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_bn_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS: