mxfp4: CUDA GEMV

This commit is contained in:
Iwan Kawrakow
2025-08-08 19:57:20 +03:00
parent 3466dbda40
commit c0449207cf
3 changed files with 58 additions and 0 deletions

View File

@@ -550,6 +550,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
static constexpr int qi = QI4_NL;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
static constexpr int qk = QK4_NL;
static constexpr int qr = QR4_NL;
static constexpr int qi = QI4_NL;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
static constexpr int qk = QK_K;

View File

@@ -31,6 +31,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1;
case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1;
case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1;
case GGML_TYPE_MXFP4 : return vec_dot_mxfp4_q8_1;
case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1;
case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1;
default : return nullptr;
@@ -56,6 +57,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ;
case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ;
case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ;
case GGML_TYPE_MXFP4 : return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ;
default : return 1;
}
@@ -417,6 +419,14 @@ static void mul_mat_vec_iq4_nl_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_mxfp4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_MXFP4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
@@ -509,6 +519,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_mxfp4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
@@ -686,6 +699,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KL:

View File

@@ -17,6 +17,15 @@ static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32
return x32;
}
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
const uint8_t * x8 = (const uint8_t *)x;
int x32 = x8[4*i32 + 0] | (x8[4*i32 + 1] << 8);
x32 |= (x8[4*i32 + 2] | (x8[4*i32 + 3] << 8)) << 16;
return x32;
}
static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
@@ -1167,6 +1176,34 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
return d * sumi;
}
#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
const int * q8 = (const int *) bq8_1->qs + iqs;
constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
int sumi = 0;
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
}
union { float f; uint32_t u; } helper;
helper.u = bq4->e >= 2 ? uint32_t(bq4->e - 1) << 23u : uval[bq4->e];
return helper.f * __low2float(bq8_1->ds) * sumi;
}
#define VDR_IQ4_XS_Q8_1_MMVQ 4
#define VDR_IQ4_XS_Q8_1_MMQ 4