mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq2_kl: MMVQ
We get PP-128(L3-8B) = 162 t/s. Which means that this is not quite as good as it should be as (almost) same bpq q2_K is at 170 t/s.
This commit is contained in:
@@ -1020,68 +1020,45 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
|
||||
__device__ __forceinline__ void vec_dot_iq2_kl_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
|
||||
|
||||
return;
|
||||
|
||||
float d = __half2float(*(const half *)vbq);
|
||||
const block_iq3_ks * bq3 = (const block_iq3_ks *)((const char *)vbq + sizeof(half)) + kbx;
|
||||
const block_iq2_kl * bq2 = (const block_iq2_kl *)((const char *)vbq + sizeof(half)) + kbx;
|
||||
|
||||
int iqs = iiqs/4;
|
||||
const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
|
||||
// Each thread processes 8 quants in each of the 4 32-blocks
|
||||
const int il8 = iqs%4; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
|
||||
const int ib64 = iqs/2; // 0...3. 0 works on quants 0...63, 1 on quants 64...127, etc.
|
||||
// Each thread processes 16 quants in each of the 2 32-blocks
|
||||
const int il16 = iqs%2; // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
|
||||
|
||||
const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
|
||||
const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;
|
||||
const uint16_t * ql = (const uint16_t *)bq2->qs + 8*ib64 + 4*il16;
|
||||
const uint16_t * qh = (const uint16_t *)bq2->qh + 4*il16;
|
||||
|
||||
int32_t aux32;
|
||||
const uint8_t * aux8 = (const uint8_t *)&aux32;
|
||||
|
||||
uint16_t extra = bq3->extra >> 4*ib128;
|
||||
uint16_t extra_v = extra >> 8;
|
||||
const int * q8l = (const int *)bq8_1[2*ib64+0].qs + 4*il16;
|
||||
const int * q8h = (const int *)bq8_1[2*ib64+1].qs + 4*il16;
|
||||
|
||||
const uint16_t * values1 = iq3k_table + ((extra_v << 6) & 0x40);
|
||||
const uint16_t * values2 = iq3k_table + ((extra_v << 5) & 0x40);
|
||||
const uint16_t * values3 = iq3k_table + ((extra_v << 4) & 0x40);
|
||||
const uint16_t * values4 = iq3k_table + ((extra_v << 3) & 0x40);
|
||||
|
||||
const int * q8;
|
||||
int sumi[4] = {0, 0, 0, 0};
|
||||
int v;
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
int v1, v2;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
|
||||
uint32_t vh = ((qh[2*i+0] | (qh[2*i+1] << 16)) >> 4*ib128) << 2;
|
||||
uint32_t vl = ql[2*i+0] | (ql[2*i+1] << 16);
|
||||
uint32_t vh = (qh[2*i+0] | (qh[2*i+1] << 16)) >> 2*ib64;
|
||||
|
||||
q8 = (const int *)bq8_1[4*ib128+0].qs + 2*il8;
|
||||
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
|
||||
v = int_from_table_2(aux8, values1);
|
||||
sumi[0] = ggml_cuda_dp4a(v, q8[i], sumi[0]);
|
||||
vl >>= 2; vh >>= 1;
|
||||
|
||||
q8 += sizeof(block_q8_1)/4;
|
||||
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
|
||||
v = int_from_table_2(aux8, values2);
|
||||
sumi[1] = ggml_cuda_dp4a(v, q8[i], sumi[1]);
|
||||
vl >>= 2; vh >>= 1;
|
||||
|
||||
q8 += sizeof(block_q8_1)/4;
|
||||
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
|
||||
v = int_from_table_2(aux8, values3);
|
||||
sumi[2] = ggml_cuda_dp4a(v, q8[i], sumi[2]);
|
||||
vl >>= 2; vh >>= 1;
|
||||
|
||||
q8 += sizeof(block_q8_1)/4;
|
||||
aux32 = (vl & 0x03030303) | (vh & 0x04040404);
|
||||
v = int_from_table_2(aux8, values4);
|
||||
sumi[3] = ggml_cuda_dp4a(v, q8[i], sumi[3]);
|
||||
aux32 = (vl & 0x0f0f0f0f) | ((vh << 4) & 0x10101010);
|
||||
v1 = iq2kl_values[aux8[0]] | (iq2kl_values[aux8[1]] << 16);
|
||||
v2 = iq2kl_values[aux8[2]] | (iq2kl_values[aux8[3]] << 16);
|
||||
sumi1 = ggml_cuda_dp4a(v1, q8l[2*i+0], ggml_cuda_dp4a(v2, q8l[2*i+1], sumi1));
|
||||
|
||||
aux32 = ((vl >> 4) & 0x0f0f0f0f) | ((vh << 3) & 0x10101010);
|
||||
v1 = iq2kl_values[aux8[0]] | (iq2kl_values[aux8[1]] << 16);
|
||||
v2 = iq2kl_values[aux8[2]] | (iq2kl_values[aux8[3]] << 16);
|
||||
sumi2 = ggml_cuda_dp4a(v1, q8h[2*i+0], ggml_cuda_dp4a(v2, q8h[2*i+1], sumi2));
|
||||
}
|
||||
const uint16_t * sl16 = (const uint16_t *)bq3->scales;
|
||||
aux32 = __vsub4(((sl16[0] | (sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f, 0x10101010);
|
||||
const int8_t * a8 = (const int8_t *)&aux32;
|
||||
*result += d * (__low2float(bq8_1[4*ib128+0].ds) * (a8[0] + ((extra << 4) & 0x10)) * sumi[0] +
|
||||
__low2float(bq8_1[4*ib128+1].ds) * (a8[1] + ((extra << 3) & 0x10)) * sumi[1] +
|
||||
__low2float(bq8_1[4*ib128+2].ds) * (a8[2] + ((extra << 2) & 0x10)) * sumi[2] +
|
||||
__low2float(bq8_1[4*ib128+3].ds) * (a8[3] + ((extra << 1) & 0x10)) * sumi[3]);
|
||||
|
||||
auto sh = bq2->scales_h >> 4*ib64;
|
||||
int ls1 = int(((bq2->scales_l[(2*ib64+0)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
int ls2 = int(((bq2->scales_l[(2*ib64+1)%4] >> 4*(ib64/2)) & 0xf) | ((sh << 2) & 0x30)) - 32;
|
||||
|
||||
*result += d * (__low2float(bq8_1[2*ib64+0].ds) * ls1 * sumi1 + __low2float(bq8_1[2*ib64+1].ds) * ls2 * sumi2);
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user