iq1bn: fix scalar dot product

The fix makes it faster on the Ryzen-7950X (10.5 t/s vs 8.2 t/s before)
but slower on the M2 (6.8 t/s vs 8.6 t/s before).
This commit is contained in:
Kawrakow
2024-07-17 13:37:18 +03:00
parent 04decf3fc5
commit 02dc036187

View File

@@ -182,6 +182,15 @@ void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) {
}
}
namespace {
// Maps a byte `q` and a multiplier index `i` to a value in {-1, 0, +1}.
//   q - input byte
//   i - index into IQ1BNQuantizer::k_mult (table defined outside this block;
//       the caller must pass an index that is valid for it)
// The product k_mult[i]*q is truncated to 8 bits by the uint8_t store. Three
// times that truncated value lies in 0...765, so its high byte is 0, 1 or 2;
// subtracting 1 yields the signed result.
inline int8_t iq1bn_dequant(uint8_t q, int i) {
    const uint8_t wrapped = IQ1BNQuantizer::k_mult[i]*q;
    const int digit = (3*wrapped) >> 8;
    return static_cast<int8_t>(digit - 1);
}
}
void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
GGML_UNUSED(bs);
@@ -204,29 +213,29 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
int sumi[8] = {};
int8_t q1[16];
for (int i = 0; i < nblock; ++i) {
auto ql = x[i].ql;
auto extra = x[i].extra;
for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
for (int k = 0; k < 3; ++k) {
uint8_t q = *ql++;
for (int j = 0; j < 5; ++j) {
uint8_t v = IQ1BNQuantizer::k_mult[j]*q;
int8_t vs = 3*v >> 8;
q1[5*k+j] = vs - 1;
for (int ii = 0; ii < nblock; ii += 32) {
int16_t sum16[8] = {};
int nb = std::min(ii + 32, nblock);
for (int i = ii; i < nb; ++i) {
auto ql = x[i].ql;
auto extra = x[i].extra;
for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
for (int k = 0; k < 3; ++k) {
uint8_t q = *ql++;
for (int j = 0; j < 5; ++j) q1[5*k+j] = iq1bn_dequant(q, j);
}
q1[15] = iq1bn_dequant(extra, i16);
// We collect 8 q8 values per block into each element of sum16
// => 32 x 8 = 256 values in each loop over i, so this cannot overflow the int16_t range
// (q8 is in -127...127, and hence the sum is in -32512...32512)
for (int j = 0; j < 8; ++j) sum16[j] += q8[2*j+0]*q1[2*j+0] + q8[2*j+1]*q1[2*j+1];
q8 += 16;
}
uint8_t v = IQ1BNQuantizer::k_mult[i16]*extra;
int8_t vs = 3*v >> 8;
q1[15] = vs - 1;
for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[j];
q8 += 8;
for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[8+j];
q8 += 8;
}
for (int j = 0; j < 8; ++j) sumi[j] += sum16[j];
}
*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
*s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]);
}
void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {