cuda q4_0_r4: dequantize works

This commit is contained in:
Iwan Kawrakow
2024-12-06 15:25:45 +02:00
parent 3e6851621c
commit 85a0730447
6 changed files with 67 additions and 0 deletions

View File

@@ -2853,6 +2853,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_Q4_0_R4:
return true;
default:
return false;

View File

@@ -557,6 +557,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};
template<>
// Traits for Q4_0_R4 — the 4-row-interleaved variant of Q4_0. The per-block
// geometry is identical to plain Q4_0 (the R4 format only reorders how 4 rows
// are stored, presumably; confirm against the quantization code).
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0_R4> {
static constexpr int qk = QK4_0; // quantized values per block
static constexpr int qr = QR4_0; // quantization ratio (values packed per byte)
static constexpr int qi = QI4_0; // 32-bit ints per quant block
};
//////////////////////
struct ggml_cuda_device_info {

View File

@@ -74,6 +74,35 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
template<typename dst_t>
// Dequantize Q4_0_R4 data (groups of 4 interleaved Q4_0 rows) into yy.
// Launch contract: one CUDA block of exactly 32 threads per 256 output values
// (see dequantize_row_q4_0_r4_cuda). vx points at block_iq4_nl_x4 records,
// each covering 4 rows x 32 values — Q4_0_R4 apparently reuses that storage
// layout (4 half scales + 64 nibble bytes); verify against the quantizer.
// NOTE(review): `row4*n_per_row/64` parses as `(row4*n_per_row)/64`; this is
// only exact when n_per_row is a multiple of 64 — confirm callers guarantee it.
static __global__ void dequantize_block_q4_0_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row) {
const int64_t ii = blockIdx.x; // linear index of this 256-value chunk
int row4 = (256*ii)/(4*n_per_row); // which group of 4 interleaved rows
const int64_t i = ii - row4*n_per_row/64; // chunk index within that row group
// assume 32 threads
const int tid = threadIdx.x;
int is = tid/16; // 0 or 1: 1st or 2nd 128-value half (i.e. which x4 record)
int j = tid%16; // 0...15: thread index inside the 128-value half
int l = j/4; // 0....3: which 16-byte slice of the record's nibbles
int k = j%4; // 0....3: row index within the group of 4 interleaved rows
int ll = 16*(l%2) + 4*(l/2); // de-interleaving offset inside the 32-value q4_0 block
dst_t * y = yy + (4*row4 + k)*n_per_row + 32*(2*i+is) + ll;
const block_iq4_nl_x4 * x = (const block_iq4_nl_x4 *)vx + 2*ii + is; // 2 records per 256 values
const float d = __half2float(x->d[k]); // per-row scale
const float dm = -8*d; // q4_0 implicit -8 offset folded into the scale
const uint8_t * q = x->qs + 16*l + 4*k;
for (int n = 0; n < 4; ++n) {
// low nibble -> first 8-value half, high nibble -> second half
y[n+0] = d * (q[n] & 0xF) + dm;
y[n+8] = d * (q[n] >> 4) + dm;
}
}
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
@@ -818,6 +847,13 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
// Host launcher for dequantize_block_q4_0_r4: converts nrows rows of
// Q4_0_R4 data (n_per_row values each) into y on the given stream.
// One 32-thread block handles 256 output values.
static void dequantize_row_q4_0_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t total_values = nrows * n_per_row;
const int num_blocks = (total_values + 255) / 256; // ceil-div: 256 values per block
dequantize_block_q4_0_r4<<<num_blocks, 32, 0, stream>>>(vx, y, n_per_row);
}
template<typename dst_t>
static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1073,6 +1109,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_0_R4:
return dequantize_row_q4_0_r4_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
@@ -1147,6 +1185,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_0_R4:
return dequantize_row_q4_0_r4_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:

View File

@@ -168,6 +168,11 @@ void iqk_mul_mat_vec_q_cuda(
}
}
// Placeholder dot product for the Q4_0_R4 MMVQ path — NOT implemented yet
// (this commit only wires up dequantization). Until a real implementation
// lands, the mat-vec kernel instantiated with this function produces zeros.
__device__ __forceinline__ float vec_dot_q4_0_r4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
return 0;
}
__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
int & val1, int & val2) {
@@ -728,6 +733,13 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
} // namespace
// Launch the MMVQ mat-vec kernel for Q4_0_R4 x Q8_1 on `stream`.
// NOTE(review): the VDR template argument is borrowed from IQ4_K
// (VDR_IQ4_K_Q8_1_MMVQ) rather than a Q4_0_R4-specific constant. Harmless
// while vec_dot_q4_0_r4_q8_1 is a stub returning 0, but revisit when the
// real dot product is implemented.
void mul_mat_vec_q4_0_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
iqk_mul_mat_vec_q_cuda<GGML_TYPE_Q4_0_R4, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_q4_0_r4_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void mul_mat_vec_iq2_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

View File

@@ -39,3 +39,7 @@ void mul_mat_vec_iq1_bn_q8_1_cuda(
void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
void mul_mat_vec_q4_0_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

View File

@@ -455,6 +455,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_0_R4:
mul_mat_vec_q4_0_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;