diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index e38e9568..90967602 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2853,6 +2853,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ6_K:
                     case GGML_TYPE_IQ1_BN:
                     case GGML_TYPE_IQ2_BN:
+                    case GGML_TYPE_Q4_0_R4:
                         return true;
                     default:
                         return false;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 2eba527f..4880e7f0 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -557,6 +557,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0_R4> {
+    static constexpr int qk = QK4_0;
+    static constexpr int qr = QR4_0;
+    static constexpr int qi = QI4_0;
+};
+
 //////////////////////
 
 struct ggml_cuda_device_info {
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index b9baee1b..e9af7cc8 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -74,6 +74,35 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
 #endif // __CUDA_ARCH__ >= CC_PASCAL
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row) {
+
+    const int64_t ii = blockIdx.x;
+    int row4 = (256*ii)/(4*n_per_row);
+    const int64_t i = ii - row4*n_per_row/64;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    int is = tid/16; // 0 or 1: 1st or 2nd block of 128
+    int j = tid%16;  // 0...15: index inside the block of 128
+    int l = j/4;     // 0....3: index inside a q4_0 block
+    int k = j%4;     // 0....3: row index in the group of 4 rows
+    int ll = 16*(l%2) + 4*(l/2);
+
+    dst_t * y = yy + (4*row4 + k)*n_per_row + 32*(2*i+is) + ll;
+
+    const block_iq4_nl_x4 * x = (const block_iq4_nl_x4 *)vx + 2*ii + is;
+    const float d = __half2float(x->d[k]);
+    const float dm = -8*d;
+
+    const uint8_t * q = x->qs + 16*l + 4*k;
+
+    for (int n = 0; n < 4; ++n) {
+        y[n+0] = d * (q[n] & 0xF) + dm;
+        y[n+8] = d * (q[n] >> 4) + dm;
+    }
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 
@@ -818,6 +847,13 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
+template<typename dst_t>
+static void dequantize_row_q4_0_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_0_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row);
+}
+
 template<typename dst_t>
 static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1073,6 +1109,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_0_R4:
+            return dequantize_row_q4_0_r4_cuda;
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
@@ -1147,6 +1185,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_0_R4:
+            return dequantize_row_q4_0_r4_cuda;
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 36dbb52a..17fe47d7 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -168,6 +168,11 @@ void iqk_mul_mat_vec_q_cuda(
     }
 }
 
+__device__ __forceinline__ float vec_dot_q4_0_r4_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    return 0;
+}
+
 __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
     int & val1, int & val2) {
 
@@ -728,6 +733,13 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
 
 } // namespace
 
+void mul_mat_vec_q4_0_r4_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_Q4_0_R4, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_r4_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 void mul_mat_vec_iq2_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 1693a73a..901bd314 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -39,3 +39,7 @@ void mul_mat_vec_iq1_bn_q8_1_cuda(
 void mul_mat_vec_iq2_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+
+void mul_mat_vec_q4_0_r4_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index cdf13533..d53d7511 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -455,6 +455,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
+        case GGML_TYPE_Q4_0_R4:
+           mul_mat_vec_q4_0_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+           break;
        default:
            GGML_ABORT("fatal error");
            break;