mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
CUDA: slightly faster iq4_k_r4 GEMV
We are now within 3% of iq4_k
This commit is contained in:
@@ -21,6 +21,9 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K_R4> {
|
||||
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
// constexpr int vdr = get_vdr_mmvq(type);
|
||||
|
||||
// QI4_XS = 256/(4*2) = 32
|
||||
// vdr = 4, qi = 32 -> qi/vdr = 8, kqs = 4*(tid%8), blocks_per_iter = 4*1*32/32 = 4
|
||||
// vdr = 2, qi = 32 -> qi/vdr =16, kqs = 2*(tid%16), blocks_per_iter = 2*1*32/32 = 2
|
||||
namespace {
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_mul_mat_vec_q(
|
||||
@@ -254,39 +257,34 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
}
|
||||
|
||||
// TODO
|
||||
__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
const block_iq4_k_r4 * bq4 = (const block_iq4_k_r4 *)vbq + kbx;
|
||||
|
||||
// iqs is 0...28
|
||||
const int ib32 = iqs/4;
|
||||
const float d8 = __low2float(bq8_1[ib32].ds);
|
||||
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
// iqs is 0...28 in steps of 2
|
||||
const int ib16 = iqs/2;
|
||||
const float d8 = __low2float(bq8_1[ib16/2].ds);
|
||||
const int32_t * q8 = (const int *)bq8_1[ib16/2].qs + 4*(ib16%2);
|
||||
|
||||
int scales[2];
|
||||
int ib32 = ib16/2;
|
||||
int is = ib16%2;
|
||||
int scales;
|
||||
const uint32_t * scales_l = (const uint32_t *)bq4->scales_l;
|
||||
const uint32_t * scales_h = (const uint32_t *)bq4->scales_h;
|
||||
scales[0] = __vsub4(((scales_l[2*(ib32%4)+0] >> 4*(ib32/4)) & 0x0f0f0f0f) | (((scales_h[2*(ib32%2)+0] >> 2*(ib32/2)) & 0x03030303) << 4), 0x20202020);
|
||||
scales[1] = __vsub4(((scales_l[2*(ib32%4)+1] >> 4*(ib32/4)) & 0x0f0f0f0f) | (((scales_h[2*(ib32%2)+1] >> 2*(ib32/2)) & 0x03030303) << 4), 0x20202020);
|
||||
const int8_t * s8 = (const int8_t *)scales;
|
||||
scales = __vsub4(((scales_l[2*(ib32%4)+is] >> 4*(ib32/4)) & 0x0f0f0f0f) | (((scales_h[2*(ib32%2)+is] >> 2*(ib32/2)) & 0x03030303) << 4), 0x20202020);
|
||||
const int8_t * s8 = (const int8_t *)&scales;
|
||||
int2 val1, val2;
|
||||
const int * q4 = (const int *)bq4->qs + 16*ib32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto values1 = iq4k_values + (((bq4->extra[i+0] >> ib32) & 1) << 4);
|
||||
auto values2 = iq4k_values + (((bq4->extra[i+4] >> ib32) & 1) << 4);
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
val1 = get_int_from_table_16(q4[i+ 0], values1);
|
||||
auto values1 = iq4k_values + (((bq4->extra[i+4*is] >> ib32) & 1) << 4);
|
||||
int sumi1 = 0;
|
||||
val1 = get_int_from_table_16(q4[i+4*is+0], values1);
|
||||
sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[2], sumi1));
|
||||
val2 = get_int_from_table_16(q4[i+ 4], values2);
|
||||
sumi2 = ggml_cuda_dp4a(val2.x, q8[4], ggml_cuda_dp4a(val2.y, q8[6], sumi2));
|
||||
val1 = get_int_from_table_16(q4[i+ 8], values1);
|
||||
val1 = get_int_from_table_16(q4[i+4*is+8], values1);
|
||||
sumi1 = ggml_cuda_dp4a(val1.x, q8[1], ggml_cuda_dp4a(val1.y, q8[3], sumi1));
|
||||
val2 = get_int_from_table_16(q4[i+12], values2);
|
||||
sumi2 = ggml_cuda_dp4a(val2.x, q8[5], ggml_cuda_dp4a(val2.y, q8[7], sumi2));
|
||||
const float d = __half2float(bq4->d[i]) * d8;
|
||||
result[i] += d * (sumi1 * s8[i] + sumi2 * s8[i+4]);
|
||||
result[i] += d * sumi1 * s8[i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -863,7 +861,7 @@ void mul_mat_vec_iq4_k_r4_q8_1_cuda(
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
|
||||
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K_R4, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K_R4, 2, vec_dot_iq4_k_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
}
|
||||
|
||||
void mul_mat_vec_iq4_ks_q8_1_cuda(
|
||||
|
||||
Reference in New Issue
Block a user