iq1bn: adjust scalar dot product and some cleanup

This commit is contained in:
Iwan Kawrakow
2024-07-17 08:44:46 +02:00
parent 873a790ee2
commit ba00f23ea1
3 changed files with 31 additions and 131 deletions

View File

@@ -138,7 +138,6 @@ void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);
void iq3xs_init_impl(int grid_size);
void iq3xs_free_impl(int grid_size);
void iq1bn_init_impl(void);
#ifdef __cplusplus
}

2
ggml.c
View File

@@ -21119,8 +21119,6 @@ void ggml_quantize_init(enum ggml_type type) {
case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break;
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ1_BN: iq1bn_init_impl(); break;
default: // nothing
break;
}

View File

@@ -27,7 +27,6 @@
#include <array>
#include <algorithm>
#include <cstring>
#include <mutex>
namespace {
@@ -38,52 +37,7 @@ inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
struct IQ1BNData {
IQ1BNData();
std::vector<std::pair<int16_t, bool>> map;
std::vector<uint16_t> rmap;
};
const IQ1BNData& get_iq1bn_data() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static IQ1BNData iq1bn;
return iq1bn;
}
IQ1BNData::IQ1BNData() {
map.resize(1 << 16, {int16_t(-1), false});
uint64_t aux64;
uint8_t * aux8 = (uint8_t *)&aux64;
std::vector<uint64_t> values;
values.reserve(6561);
rmap.reserve(6561);
for (int i = 0; i < (1 << 16); ++i) {
bool is_good = true;
for (int j = 0; j < 8; ++j) {
aux8[j] = (i >> 2*j) & 3;
if (aux8[j] == 3u) { is_good = false; break; }
}
if (!is_good) continue;
auto orig = aux64;
for (int j = 0; j < 8; ++j) aux8[j] = 2 - aux8[j];
int k = 0;
for (; k < int(values.size()); ++k) {
if (values[k] == aux64) break;
}
if (k < int(values.size())) {
map[i] = {k, true};
} else {
map[i].first = values.size();
values.push_back(orig);
rmap.push_back(i);
}
}
printf("==================== %s: initialized %d grid points\n", __func__, int(rmap.size()));
}
struct IQ1BNQuantizer {
constexpr static int block_size = QK_IQ1BN;
int8_t L[QK_IQ1BN];
void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
void quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix);
@@ -95,27 +49,9 @@ struct IQ1BNQuantizer {
}
return max_in_row;
}
static uint16_t quantize_one_block_1bn(const IQ1BNData& iq1l, const float * xb, int8_t * L, uint8_t * ql, uint8_t * qh);
static constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
};
uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const float * xb, int8_t * L, uint8_t * ql, uint8_t * qh) {
for (int j = 0; j < QK_IQ1BN; ++j) {
L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
}
uint16_t extra = 0;
for (int k = 0; k < QK_IQ1BN/8; ++k) {
auto Lk = L + 8*k;
uint16_t u = 0;
for (int j = 0; j < 8; ++j) u |= (Lk[j] << 2*j);
auto& val = iq1bn.map[u];
GGML_ASSERT(val.first >= 0);
ql[k] = val.first & 255;
qh[k/2] |= (val.first >> 8) << 4*(k%2);
if (val.second) extra |= (1 << k);
}
return extra;
}
void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
static const int k_nb[6] = {1, 3, 9, 27, 81, 243};
@@ -152,9 +88,6 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
const int nblock = n_per_row/QK_IQ1BN;
//auto max_in_row = row_max(n_per_row, src);
//printf("%s: max = %g\n", __func__, max_in_row);
constexpr int Nj = QK_IQ1BN/4;
for (int ib = 0; ib < nblock; ++ib) {
@@ -170,10 +103,6 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
}
void iq1bn_init_impl(void) {
get_iq1bn_data();
}
size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
IQ1BNQuantizer iq1bn;
int nblock = n_per_row/QK_IQ1BN;
@@ -197,21 +126,19 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
for (int i = 0; i < nblock; ++i) {
uint8_t extra = x[i].extra;
auto ql = x[i].ql;
for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
for (int k = 0; k < 3; ++k) {
for (int j = 0; j < 5; ++j) {
uint8_t v = ql[k]*k_mult[j];
uint8_t v = ql[k]*IQ1BNQuantizer::k_mult[j];
int8_t vs = ((v + (v >> 1)) >> 7);
*y++ = vs - 1;
}
}
ql += 3;
uint8_t v = extra*k_mult[i16];
uint8_t v = extra*IQ1BNQuantizer::k_mult[i16];
int8_t vs = ((v + (v >> 1)) >> 7);
*y++ = vs - 1;
}
@@ -268,44 +195,38 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
return;
}
// TODO
const block_iq1_bn * x = (const block_iq1_bn *)vx;
//constexpr uint16_t k_magic = 0xaaaa;
const float * d8 = (const float *)vy;
const int8_t * q8 = (const int8_t *)(d8 + 4);
int nblock = n / QK_IQ1BN;
//const block_iq1_bn * x = (const block_iq1_bn *)vx;
int sumi[8] = {};
int8_t q1[16];
//const float * d8 = (const float *)vy;
//const int8_t * q8 = (const int8_t *)(d8 + 4);
//int nblock = n / QK_IQ1BN;
for (int i = 0; i < nblock; ++i) {
auto ql = x[i].ql;
auto extra = x[i].extra;
for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
for (int k = 0; k < 3; ++k) {
uint8_t q = *ql++;
for (int j = 0; j < 5; ++j) {
uint8_t v = IQ1BNQuantizer::k_mult[j]*q;
int8_t vs = 3*v >> 8;
q1[5*k+j] = vs - 1;
}
}
uint8_t v = IQ1BNQuantizer::k_mult[i16]*extra;
int8_t vs = 3*v >> 8;
q1[15] = vs - 1;
for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[j];
q8 += 8;
for (int j = 0; j < 8; ++j) sumi[j] += q8[j]*q1[8+j];
q8 += 8;
}
}
//int sumi[8] = {};
//uint32_t aux32[2];
//const int8_t * aux8 = (const int8_t *)aux32;
//for (int i = 0; i < nblock; ++i) {
// auto qh = x[i].qh;
// auto ql = x[i].ql;
// auto extra = x[i].extra;
// for (int j = 0; j < QK_IQ1BN/16; ++j) {
// uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
// uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
// uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
// uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
// extra >>= 2;
// aux32[0] = val1 | (val1 << 14);
// aux32[1] = (aux32[0] >> 4) & 0x03030303;
// aux32[0] &= 0x03030303;
// for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
// q8 += 8;
// aux32[0] = val2 | (val2 << 14);
// aux32[1] = (aux32[0] >> 4) & 0x03030303;
// aux32[0] &= 0x03030303;
// for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
// q8 += 8;
// }
//}
//*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
}
void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
@@ -355,24 +276,6 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
}
void quantize_row_q8_K64_reference(const float * x, block_q8_K64 * y, int64_t k) {
//assert(k % 64 == 0);
//const int64_t nb = k / 64;
// Check if a row-wise scale works. It almost does, PPL is only ~0.02 higher
//float amax = 0;
//for (int j = 0; j < k; ++j) {
// float ax = fabsf(x[j]);
// amax = MAX(ax, amax);
//}
//float d = amax/127;
//float id = d ? 1/d : 0.f;
//for (int i = 0; i < nb; i++) {
// for (int j = 0; j < 64; ++j) y[i].qs[j] = nearest_int(id*x[j]);
// y[i].d = d;
// x += 64;
//}
float * dptr = (float *)y;
auto qs = (int8_t *)(dptr + 4);