New iq2_kt: CUDA GEMV

This commit is contained in:
Iwan Kawrakow
2025-06-08 17:51:28 +03:00
parent e095b0fa80
commit 1efb3adc9b
3 changed files with 51 additions and 0 deletions

View File

@@ -471,6 +471,41 @@ __device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t km = 0x3f3f3f3f;
float scale = *(const float *)vbq;
const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx;
// iqs is 0...28
const int ib32 = iqs/4;
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
const float dl = scale * ls * 1.05f;
auto ql = (const uint16_t *)bq2->ql;
int sumi = 0;
for (int j = 0; j < 4; ++j) {
uint32_t val = ql[4*ib32+j] + 4096;
int v4 = 0;
for (int k = 0; k < 4; ++k) {
val = ka*val + kb;
v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
}
sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
v4 = 0;
for (int k = 0; k < 4; ++k) {
val = ka*val + kb;
v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
}
sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
}
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4
@@ -1263,6 +1298,14 @@ void mul_mat_vec_iq4_kt_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq2_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,

View File

@@ -105,3 +105,8 @@ void mul_mat_vec_iq4_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq2_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

View File

@@ -529,6 +529,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
case GGML_TYPE_IQ4_KT:
mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_KT:
mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
@@ -691,6 +693,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ2_KT:
return true;
default:
return false;