mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 09:04:10 +00:00
CUDA: faster iq2_k_r4 GEMV
This commit is contained in:
@@ -527,45 +527,6 @@ __device__ __forceinline__ void vec_dot_iq3_k_r4_q8_1(
|
||||
}
|
||||
}
|
||||
|
||||
// Adds to result[0..3] the scaled integer dot products between one slice of
// the iq2_k_r4 block `kbx` and four packed q8_1 words.
//
//   vbq    : start of the block_iq2_k_r4 array; only block kbx is read.
//   bq8_1  : q8_1 blocks; block (iqs/2)/2 is read, starting at packed word 4*((iqs/2)%2).
//   kbx    : index of the quantized block.
//   iqs    : 0...30 in steps of 2.
//   result : four running sums, one per entry of the block's d[] array.
__device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

    const block_iq2_k_r4 * blk = (const block_iq2_k_r4 *)vbq + kbx;

    // iqs is 0...30 in steps of 2
    const int sub16 = iqs/2;
    const int blk32 = sub16/2;   // which q8_1 block / which bit of extra[]
    const int sub   = sub16%2;   // first or second half inside that block

    // Activation side: float scale (low half of ds) and four packed int8x4 words.
    const float     act_scale = __low2float(bq8_1[blk32].ds);
    const int32_t * act_q     = (const int *)bq8_1[blk32].qs + 4*sub;

    // Four 4-bit scales are pulled out of one 32-bit word (low or high nibble of
    // each byte, chosen by blk32/4) and shifted to a signed range by subtracting 8.
    const int * packed_scales = (const int *)blk->scales;
    const int   nibbles       = (packed_scales[2*(blk32%4) + sub] >> 4*(blk32/4)) & 0x0f0f0f0f;
    int scales = __vsub4(nibbles, 0x08080808);
    const int8_t * s8 = (const int8_t *)&scales;

    const int * quants = (const int *)blk->qs + 8*blk32 + 4*sub;

    // Scratch for eight 2-bit indices (one per byte), viewed byte-wise by the table lookup.
    int idx32[2];
    const uint8_t * idx8 = (const uint8_t *)idx32;

    for (int row = 0; row < 4; ++row) {
        // One bit of extra[] moves the lookup table start by 4 entries.
        const int sel = (blk->extra[row + 4*sub] >> blk32) & 1;
        auto table = iq2nl_values + (sel << 2);

        const int packed = quants[row];
        int acc = 0;
        // pair 0 consumes bit-pairs at shifts 0/2 against act_q[0..1],
        // pair 1 consumes bit-pairs at shifts 4/6 against act_q[2..3].
#pragma unroll
        for (int pair = 0; pair < 2; ++pair) {
            idx32[0] = (packed >> (4*pair    )) & 0x03030303;
            idx32[1] = (packed >> (4*pair + 2)) & 0x03030303;
            // TODO: int_from_table_4
            const int vx = int_from_table(idx8,     (const uint8_t *)table);
            const int vy = int_from_table(idx8 + 4, (const uint8_t *)table);
            acc = ggml_cuda_dp4a(vx, act_q[2*pair], ggml_cuda_dp4a(vy, act_q[2*pair + 1], acc));
        }

        const float d = __half2float(blk->d[row]) * act_scale;
        result[row] += d * acc * s8[row];
    }
}
|
||||
|
||||
// NOTE(review): both IQ6_K constants are 4; their consumers are outside this view — confirm usage at the call sites.
#define VDR_IQ6_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ6_K_Q8_1_MMQ 4
|
||||
|
||||
@@ -793,6 +754,47 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
}
|
||||
|
||||
// Adds to result[0..3] the scaled integer dot products between one slice of
// the iq2_k_r4 block `kbx` and four packed q8_1 words. Quant values are
// fetched four at a time via int_from_table_4 from iq2k_table.
//
//   vbq    : start of the block_iq2_k_r4 array; only block kbx is read.
//   bq8_1  : q8_1 blocks; block (iqs/2)/2 is read, starting at packed word 4*((iqs/2)%2).
//   kbx    : index of the quantized block.
//   iqs    : 0...30 in steps of 2.
//   result : four running sums, one per entry of the block's d[] array.
__device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

    const block_iq2_k_r4 * blk = (const block_iq2_k_r4 *)vbq + kbx;

    // iqs is 0...30 in steps of 2
    const int sub16 = iqs/2;
    const int blk32 = sub16/2;   // which q8_1 block / which bit of extra[]
    const int sub   = sub16%2;   // first or second half inside that block

    // Activation side: float scale (low half of ds) and four packed int8x4 words.
    const float     act_scale = __low2float(bq8_1[blk32].ds);
    const int32_t * act_q     = (const int *)bq8_1[blk32].qs + 4*sub;

    // Four 4-bit scales are pulled out of one 32-bit word (low or high nibble of
    // each byte, chosen by blk32/4) and shifted to a signed range by subtracting 8.
    const int * packed_scales = (const int *)blk->scales;
    const int   nibbles       = (packed_scales[2*(blk32%4) + sub] >> 4*(blk32/4)) & 0x0f0f0f0f;
    int scales = __vsub4(nibbles, 0x08080808);
    const int8_t * s8 = (const int8_t *)&scales;

    const int * quants = (const int *)blk->qs + 8*blk32 + 4*sub;

    // Scratch for eight 2-bit indices (one per byte), viewed byte-wise by the table lookup.
    int idx32[2];
    const uint8_t * idx8 = (const uint8_t *)idx32;

#pragma unroll
    for (int row = 0; row < 4; ++row) {
        // One bit of extra[] moves the lookup table start by 256 ints.
        const int sel = (blk->extra[row + 4*sub] >> blk32) & 1;
        const int * table = (const int *)iq2k_table + (sel << 8);

        const int packed = quants[row];
        int acc = 0;
        // pair 0 consumes bit-pairs at shifts 0/2 against act_q[0..1],
        // pair 1 consumes bit-pairs at shifts 4/6 against act_q[2..3].
#pragma unroll
        for (int pair = 0; pair < 2; ++pair) {
            idx32[0] = (packed >> (4*pair    )) & 0x03030303;
            idx32[1] = (packed >> (4*pair + 2)) & 0x03030303;
            const int vx = int_from_table_4(idx8,     table);
            const int vy = int_from_table_4(idx8 + 4, table);
            acc = ggml_cuda_dp4a(vx, act_q[2*pair], ggml_cuda_dp4a(vy, act_q[2*pair + 1], acc));
        }

        const float d = __half2float(blk->d[row]) * act_scale;
        result[row] += d * acc * s8[row];
    }
}
|
||||
|
||||
// NOTE(review): both IQ3_K constants are 4; their consumers are outside this view — confirm usage at the call sites.
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ3_K_Q8_1_MMQ 4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user