Adapting iq2_bn: CUDA works

This commit is contained in:
Iwan Kawrakow
2024-10-23 20:07:13 +03:00
parent 0d17e8c3c7
commit fa5bbe53f1
4 changed files with 51 additions and 47 deletions

View File

@@ -702,6 +702,47 @@ static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
return __low2float(bq8_1[iqs].ds) * scale * sumi;
}
__device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
float scale = *(const float *)vbq;
const block_iq2_bn * bq2 = (const block_iq2_bn *)((const char *)vbq + sizeof(float)) + kbx;
// iqs is 0 or 1
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
auto qs = (const uint16_t *)bq2->qs + 4*iqs;
auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
for (int j = 0; j < 2; ++j) {
int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
int vh = vl >> 4;
sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
}
auto d8l = __half22float2(bq8_1[0].ds);
auto d8h = __half22float2(bq8_1[1].ds);
#else
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
auto q8l = bq8_1[0].qs + 8*iqs;
auto q8h = bq8_1[1].qs + 8*iqs;
auto qs = bq2->qs + 8*iqs;
for (int j = 0; j < 8; ++j) {
sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
sumi2 += q8l[j+16] * (qs[j] & 0x0c);
sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
sumi4 += q8h[j+16] * (qs[j] & 0xc0);
}
auto d8l = __half22float2(bq8_1[0].ds);
auto d8h = __half22float2(bq8_1[1].ds);
return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
#endif
return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
}
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
@@ -767,6 +808,13 @@ void mul_mat_vec_iq2_tn_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
//mul_mat_vec_iq2_tn_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void mul_mat_vec_iq1_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

View File

@@ -40,3 +40,6 @@ void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

View File

@@ -23,7 +23,6 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 :
type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 :
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
@@ -332,13 +331,6 @@ static void mul_mat_vec_iq1_bn_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

View File

@@ -1167,45 +1167,6 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
return __low2float(bq8_1[iqs].ds) * sumi;
}
static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx;
// iqs is 0 or 1
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
auto qs = (const uint16_t *)bq2->qs + 4*iqs;
auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
for (int j = 0; j < 2; ++j) {
int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
int vh = vl >> 4;
sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
}
auto d8l = __half22float2(bq8_1[0].ds);
auto d8h = __half22float2(bq8_1[1].ds);
return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
#else
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
auto q8l = bq8_1[0].qs + 8*iqs;
auto q8h = bq8_1[1].qs + 8*iqs;
auto qs = bq2->qs + 8*iqs;
for (int j = 0; j < 8; ++j) {
sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
sumi2 += q8l[j+16] * (qs[j] & 0x0c);
sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
sumi4 += q8h[j+16] * (qs[j] & 0xc0);
}
auto d8l = __half22float2(bq8_1[0].ds);
auto d8h = __half22float2(bq8_1[1].ds);
return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
#endif
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;