iq1_tn: adding CUDA dot product

This commit is contained in:
Iwan Kawrakow
2024-09-09 20:00:14 +03:00
parent 5b40848742
commit f808466309
2 changed files with 78 additions and 12 deletions

View File

@@ -466,6 +466,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
static constexpr int qi = QI1_BN;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> {
static constexpr int qk = QK_IQ1BN;
static constexpr int qr = QR1_BN;
static constexpr int qi = QI1_BN;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
static constexpr int qk = QK_IQ1BN;

View File

@@ -8,6 +8,11 @@
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
// Reminder:
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
// constexpr int vdr = get_vdr_mmvq(type);
namespace {
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,7 +21,7 @@ __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__global__ void iqk_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -50,7 +55,8 @@ __global__ void iqk_mul_mat_vec_q(
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
tmp[j][i] += vec_dot_q_cuda((const void *)((const char *)vx + (row0 + i)*row_size),
&y[j*blocks_per_col_y + kby], kbx, kqs);
}
}
}
@@ -129,30 +135,32 @@ void iqk_mul_mat_vec_q_cuda(
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const int64_t row_size = ggml_row_size(type, ncols_x);
switch (ncols_y) {
case 1:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 2:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 3:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 4:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 5:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 6:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 7:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 8:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
default:
GGML_ASSERT(false);
@@ -540,6 +548,58 @@ static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
}
static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
float scale = *(const half *)vbq;
const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(half)) + kbx;
static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
// iqs is 0 or 1
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[iqs].qs;
int val[4];
for (int l = 0; l < 2; ++l) {
int8_t * a = (int8_t *)val;
const int i16 = 2*iqs + l;
for (int k = 0; k < 3; ++k) {
uint8_t q = bq1->ql[3*i16+k];
for (int j = 0; j < 5; ++j) {
uint8_t v = k_mult[j]*q;
int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
*a++ = vs-1;
}
}
uint8_t v = k_mult[i16]*bq1->extra;
int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
*a++ = vs-1;
sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
}
#else
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 2; ++l) {
const int i16 = 2*iqs + l;
for (int k = 0; k < 3; ++k) {
uint8_t q = bq1->ql[3*i16+k];
for (int j = 0; j < 5; ++j) {
uint8_t v = k_mult[j]*q;
int8_t vs = (v + (v >> 1)) >> 7;
sumi += q8[j]*(vs - 1);
}
q8 += 5;
}
uint8_t v = k_mult[i16]*bq1->extra;
int8_t vs = (v + (v >> 1)) >> 7;
sumi += q8[0]*(vs - 1);
q8++;
}
#endif
return __low2float(bq8_1[iqs].ds) * scale * sumi;
}
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -588,7 +648,6 @@ void mul_mat_vec_iq2_tn_q8_1_cuda(
void mul_mat_vec_iq1_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
//printf("%s\n", __func__);
//iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_TN, 1, vec_dot_iq1_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}