iq2_ks: scalar dot product

This commit is contained in:
Iwan Kawrakow
2024-10-13 08:48:13 +03:00
parent 1f6e498dfa
commit 18cdf624f8

View File

@@ -1042,7 +1042,7 @@ void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_R
}
void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
void vec_dot_iq2_ks_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
GGML_UNUSED(nrc);
@@ -1056,7 +1056,43 @@ void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void *
}
#endif
GGML_ABORT("not implemented");
const ggml_half * dptr = (const ggml_half *)vx;
const float d = GGML_FP16_TO_FP32(*dptr);
const block_iq2_ks * x = (const block_iq2_ks *)(dptr + 1);
const block_q8_K * y = (const block_q8_K *)vy;
const int nb = n / QK_K;
float sumf = 0;
for (int i = 0; i < nb; i++) {
const uint8_t * qs = x[i].qs;
const int8_t * q8 = y[i].qs;
uint16_t extra = x[i].extra;
int sumi = 0;
for (int ib128 = 0; ib128 < QK_K/128; ++ib128) {
int d1 = (((x[i].scales[2*ib128+0] & 0xf) | ((extra >> 4) & 0x10)) - 16);
int d2 = (((x[i].scales[2*ib128+0] >> 4) | ((extra >> 5) & 0x10)) - 16);
int d3 = (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16);
int d4 = (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16);
const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values;
const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values;
const int8_t * values3 = extra & 4 ? iq2nl_values + 4 : iq2nl_values;
const int8_t * values4 = extra & 8 ? iq2nl_values + 4 : iq2nl_values;
extra >>= 4;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
for (int j = 0; j < 32; ++j) {
sumi1 += q8[j+ 0] * values1[(qs[j] >> 0) & 3];
sumi2 += q8[j+32] * values2[(qs[j] >> 2) & 3];
sumi3 += q8[j+64] * values3[(qs[j] >> 4) & 3];
sumi4 += q8[j+96] * values4[(qs[j] >> 6) & 3];
}
sumi += d1*sumi1 + d2*sumi2 + d3*sumi3 + d4*sumi4;
q8 += 128;
qs += 32;
}
sumf += y[i].d * sumi;
}
*s = d * sumf;
}