iq3_ks: CUDA works

This commit is contained in:
Iwan Kawrakow
2024-10-09 18:00:45 +03:00
parent 893ca1731c
commit 252c6b2d82

View File

@@ -541,68 +541,64 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
__device__ __forceinline__ float vec_dot_iq3_ks_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs) {
return 0.f;
// const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx;
//
// int iqs = iiqs/4;
// const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
// // Each thread processes 8 quants in each of the 4 32-blocks
// const int il8 = iqs%4; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
// const int shift = 4*(il8/2);
//
// const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
// const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;
//
// uint32_t aux32;
// const uint8_t * aux8 = (const uint8_t *)&aux32;
//
// const int hshift = 4*(1-ib128);
// const uint16_t sh = bq3->scales_h >> (8*ib128 + il8/2);
//
// const uint8_t extra = bq3->extra >> (8*ib128 + il8/2);
// const uint16_t * values1 = iq3k_table + ((extra << 6) & 0x40);
// const uint16_t * values2 = iq3k_table + ((extra << 5) & 0x40);
// const uint16_t * values3 = iq3k_table + ((extra << 4) & 0x40);
// const uint16_t * values4 = iq3k_table + ((extra << 3) & 0x40);
//
// const int * q8;
// int sumi[4] = {0, 0, 0, 0};
// int v;
// for (int i = 0; i < 2; ++i) {
// uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
// uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) << hshift) >> 2;
//
// q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
// aux32 = (vl & 0x03030303) | (vh & 0x04040404);
// v = int_from_table_2(aux8, values1);
// sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
// vl >>= 2; vh >>= 1;
//
// q8 += sizeof(block_q8_1)/4;
// aux32 = (vl & 0x03030303) | (vh & 0x04040404);
// v = int_from_table_2(aux8, values2);
// sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
// vl >>= 2; vh >>= 1;
//
// q8 += sizeof(block_q8_1)/4;
// aux32 = (vl & 0x03030303) | (vh & 0x04040404);
// v = int_from_table_2(aux8, values3);
// sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
// vl >>= 2; vh >>= 1;
//
// q8 += sizeof(block_q8_1)/4;
// aux32 = (vl & 0x03030303) | (vh & 0x04040404);
// v = int_from_table_2(aux8, values4);
// sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
//
// }
// const float d = __half2float(bq3->d);
// const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128;
// aux32 = ((((sl16[0] | (sl16[1] << 16)) >> shift) & 0x0f0f0f0f) << 1) | 0x01010101;
// return d * (__low2float(bq8_1[4*ib128+0].ds) * aux8[0] * (sh & 0x01 ? -1 : 1) * sumi[0] +
// __low2float(bq8_1[4*ib128+1].ds) * aux8[1] * (sh & 0x04 ? -1 : 1) * sumi[1] +
// __low2float(bq8_1[4*ib128+2].ds) * aux8[2] * (sh & 0x10 ? -1 : 1) * sumi[2] +
// __low2float(bq8_1[4*ib128+3].ds) * aux8[3] * (sh & 0x40 ? -1 : 1) * sumi[3]);
const float d = *(const float *)vbq;
const block_iq3_ks * bq3 = (const block_iq3_ks *)((const char *)vbq + sizeof(float)) + kbx;
int iqs = iiqs/4;
const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
// Each thread processes 8 quants in each of the 4 32-blocks
const int il8 = iqs%4; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
const uint32_t * ql = (const uint32_t *)bq3->qs + 8*ib128 + 2*il8;
const uint32_t * qh = (const uint32_t *)bq3->qh + 2*il8;
uint32_t aux32;
const uint8_t * aux8 = (const uint8_t *)&aux32;
const int hshift = 4*(1-ib128);
const uint16_t * values1 = iq3k_table + ((bq3->scales[4*ib128+0] << 6) & 0x40);
const uint16_t * values2 = iq3k_table + ((bq3->scales[4*ib128+0] << 6) & 0x40);
const uint16_t * values3 = iq3k_table + ((bq3->scales[4*ib128+0] << 6) & 0x40);
const uint16_t * values4 = iq3k_table + ((bq3->scales[4*ib128+0] << 6) & 0x40);
const int * q8;
int sumi[4] = {0, 0, 0, 0};
int v;
for (int i = 0; i < 2; ++i) {
uint32_t vl = ql[i];
uint32_t vh = (qh[i] << hshift) >> 2;
q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
v = int_from_table_2(aux8, values1);
sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
vl >>= 2; vh >>= 1;
q8 += sizeof(block_q8_1)/4;
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
v = int_from_table_2(aux8, values2);
sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
vl >>= 2; vh >>= 1;
q8 += sizeof(block_q8_1)/4;
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
v = int_from_table_2(aux8, values3);
sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
vl >>= 2; vh >>= 1;
q8 += sizeof(block_q8_1)/4;
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
v = int_from_table_2(aux8, values4);
sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
}
aux32 = ((const uint32_t *)bq3->scales)[ib128] & 0xfefefefe;
return d * (__low2float(bq8_1[4*ib128+0].ds) * ((int)aux8[0] - 127) * sumi[0] +
__low2float(bq8_1[4*ib128+1].ds) * ((int)aux8[1] - 127) * sumi[1] +
__low2float(bq8_1[4*ib128+2].ds) * ((int)aux8[2] - 127) * sumi[2] +
__low2float(bq8_1[4*ib128+3].ds) * ((int)aux8[3] - 127) * sumi[3]);
}