iq4_kss: CUDA works

TG-128 performance is quite good: 131 t/s for LLaMA-3.1-8B,
compared with 123 t/s for q4_0 and 128 t/s for iq4_ks.
In other words, the reduced model size more than offsets the
additional bit fiddling required for iq4_kss.
This commit is contained in:
Iwan Kawrakow
2024-10-15 15:07:30 +03:00
parent bb0e3f957e
commit 026adac30d

View File

@@ -249,22 +249,22 @@ __device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
const uint8_t * all_values = (const uint8_t *)iq4k_values;
// TODO
return 0.f;
//// iqs is 0...28
//const int ib32 = iqs/4; // Why iqs/4 ?
//const int32_t * q8 = (const int *)bq8_1[ib32].qs;
//const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
//const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
//int v1, v2;
//int sumi = 0;
//for (int j = 0; j < 4; ++j) {
// get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
// sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
// sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
//}
//return dl * __low2float(bq8_1[ib32].ds) * sumi;
// iqs is 0...28
const int ib32 = iqs/4; // Why iqs/4 ?
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
const uint8_t ls = (q4[0] >> 30) | ((q4[1] >> 28) & 0x0c) | ((q4[2] >> 26) & 0x30) | ((q4[3] >> 24) & 0xc0);
const float dl = scale * ((ls & 254) - 127);
int v1, v2;
int sumi = 0;
for (int j = 0; j < 4; ++j) {
uint32_t aux32 = (q4[j] & 0x00007fff) | ((q4[j] << 1) & 0x7fff0000);
aux32 ^= (aux32 << 1);
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
}
return dl * __low2float(bq8_1[ib32].ds) * sumi;
}
#define VDR_IQ5_K_Q8_1_MMVQ 4