mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 18:32:04 +00:00
iq2_ks: scalar dot product
This commit is contained in:
@@ -1042,7 +1042,7 @@ void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_R
|
||||
|
||||
}
|
||||
|
||||
void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
void vec_dot_iq2_ks_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
GGML_UNUSED(nrc);
|
||||
@@ -1056,7 +1056,43 @@ void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void *
|
||||
}
|
||||
#endif
|
||||
|
||||
GGML_ABORT("not implemented");
|
||||
const ggml_half * dptr = (const ggml_half *)vx;
|
||||
const float d = GGML_FP16_TO_FP32(*dptr);
|
||||
const block_iq2_ks * x = (const block_iq2_ks *)(dptr + 1);
|
||||
const block_q8_K * y = (const block_q8_K *)vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const int8_t * q8 = y[i].qs;
|
||||
uint16_t extra = x[i].extra;
|
||||
int sumi = 0;
|
||||
for (int ib128 = 0; ib128 < QK_K/128; ++ib128) {
|
||||
int d1 = (((x[i].scales[2*ib128+0] & 0xf) | ((extra >> 4) & 0x10)) - 16);
|
||||
int d2 = (((x[i].scales[2*ib128+0] >> 4) | ((extra >> 5) & 0x10)) - 16);
|
||||
int d3 = (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16);
|
||||
int d4 = (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16);
|
||||
const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values;
|
||||
const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values;
|
||||
const int8_t * values3 = extra & 4 ? iq2nl_values + 4 : iq2nl_values;
|
||||
const int8_t * values4 = extra & 8 ? iq2nl_values + 4 : iq2nl_values;
|
||||
extra >>= 4;
|
||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
sumi1 += q8[j+ 0] * values1[(qs[j] >> 0) & 3];
|
||||
sumi2 += q8[j+32] * values2[(qs[j] >> 2) & 3];
|
||||
sumi3 += q8[j+64] * values3[(qs[j] >> 4) & 3];
|
||||
sumi4 += q8[j+96] * values4[(qs[j] >> 6) & 3];
|
||||
}
|
||||
sumi += d1*sumi1 + d2*sumi2 + d3*sumi3 + d4*sumi4;
|
||||
q8 += 128;
|
||||
qs += 32;
|
||||
}
|
||||
sumf += y[i].d * sumi;
|
||||
}
|
||||
|
||||
*s = d * sumf;
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user