iq4_kt: go to 4.0 bpw

15 bits per group of 4, plus 8 bit scales ifor blocks of 32.
This gives a slightly better PPL than iq4_kss.
This commit is contained in:
Iwan Kawrakow
2024-11-15 09:38:47 +02:00
parent 21903f19b4
commit 1be0a9e0d7
5 changed files with 148 additions and 119 deletions

View File

@@ -468,10 +468,9 @@ typedef struct {
static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding");
typedef struct {
int8_t scales[QK_K/64];
uint8_t ql[QK_K/2];
uint32_t qs[QK_K/8];
} block_iq4_kt;
static_assert(sizeof(block_iq4_kt) == QK_K/2 + QK_K/64, "wrong iq4_kt block size/padding");
static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding");
typedef struct {
ggml_half d;

View File

@@ -411,30 +411,37 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
float scale = *(const float *)cx;
const block_iq4_kt * x = (const block_iq4_kt *)(cx + sizeof(float));
const float * dptr = (const float *)((const char *)vx + row * row_size);
float scale = dptr[0] * 31.75f * 1.01f;
float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
constexpr int kNumGroups = 64;
const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx1 = ql[2*ib+0] + 4096;
uint32_t idx2 = ql[2*ib+1] + 4096;
const float dl = scale * x[i].scales[ib/8] * 31.75f * 1.01f;
const uint32_t * shb = x[i].qs;
const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock;
const uint8_t * qh = ql + kNumGroups;
const int ib32 = ib/4;
const int ig = ib%4;
const int jj = ib32*8 + 2*ig;
uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + 4096;
uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + 4096;
const float dl = scale * ((const int8_t *)(shb + ib32))[0];
uint32_t s[2];
const half * h = (const half *)s;
for (int j = 0; j < 4; ++j) {
idx1 = ka*idx1 + kb; s[0] = (idx1 & kmask) ^ km32;
idx2 = ka*idx2 + kb; s[1] = (idx2 & kmask) ^ km32;
y[j+0] = dl * (float)(h[0] + h[1]);
y[j+4] = dl * (float)(h[2] + h[3]);
y[j+0] = dl * (float)(h[0] + h[1]) + row_av;
y[j+4] = dl * (float)(h[2] + h[3]) + row_av;
}
}

View File

@@ -195,55 +195,55 @@ static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ v
const half * h = (const half *)s;
for (int i = ix; i < num_blocks_per_row; i += 2) {
const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
const uint16_t * ql = (const uint16_t *)x[i].ql;
const dfloat scale1 = x[i].scales[it/8];
const dfloat scale2 = x[i].scales[it/8 + 2];
const dfloat2 dl1 = {scale1, scale1};
const dfloat2 dl2 = {scale2, scale2};
dfloat2 bdot1 = {0, 0};
dfloat2 bdot2 = {0, 0};
uint32_t val1 = ql[2*it+ 0] + 4096;
uint32_t val2 = ql[2*it+32] + 4096;
for (int k = 0; k < 2; ++k) {
val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
#ifdef GGML_CUDA_F16
bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
#else
bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
#endif
}
val1 = ql[2*it+ 1] + 4096;
val2 = ql[2*it+33] + 4096;
for (int k = 2; k < 4; ++k) {
val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
#ifdef GGML_CUDA_F16
bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
#else
bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
#endif
}
#ifdef GGML_CUDA_F16
tmp = __hfma2(dl1, bdot1, tmp);
tmp = __hfma2(dl2, bdot2, tmp);
#else
tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
#endif
// const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
// const uint16_t * ql = (const uint16_t *)x[i].ql;
// const dfloat scale1 = x[i].scales[it/8];
// const dfloat scale2 = x[i].scales[it/8 + 2];
// const dfloat2 dl1 = {scale1, scale1};
// const dfloat2 dl2 = {scale2, scale2};
// dfloat2 bdot1 = {0, 0};
// dfloat2 bdot2 = {0, 0};
// uint32_t val1 = ql[2*it+ 0] + 4096;
// uint32_t val2 = ql[2*it+32] + 4096;
// for (int k = 0; k < 2; ++k) {
// val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
// val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
// val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
// val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
//#ifdef GGML_CUDA_F16
// bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
// bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
//#else
// bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
// bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
// bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
// bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
//#endif
// }
// val1 = ql[2*it+ 1] + 4096;
// val2 = ql[2*it+33] + 4096;
// for (int k = 2; k < 4; ++k) {
// val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
// val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
// val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
// val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
//#ifdef GGML_CUDA_F16
// bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
// bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
//#else
// bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
// bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
// bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
// bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
//#endif
// }
//#ifdef GGML_CUDA_F16
// tmp = __hfma2(dl1, bdot1, tmp);
// tmp = __hfma2(dl2, bdot2, tmp);
//#else
// tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
// tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
//#endif
}
// sum up partial sums and write back result

View File

@@ -1230,7 +1230,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = vec_dot_iq4_kt_q8_k,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
.row_meta_size = 4,
.row_meta_size = 8,
},
[GGML_TYPE_IQ3_K] = {
.type_name = "iq3_k",

View File

@@ -3187,6 +3187,7 @@ public:
}
static inline int bin4(float x) { return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3; }
static inline int bin5(float x) { return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4; }
static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) {
for (int ibl = 0; ibl < nblock; ++ibl) {
@@ -3398,10 +3399,15 @@ void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, c
auto vw4 = _mm_loadu_ps(wl);
auto vw = _mm256_set_m128(vw4, vw4);
int jbest = -1;
if (ncluster == 256) {
if (ncluster == 256 || ncluster == 625) {
_mm256_storeu_ps(sx, vx);
uint8_t u = 0;
for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k);
uint16_t u = 0;
if (ncluster == 256) {
for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k);
} else {
int l = 1;
for (int k = 0; k < 4; ++k) { u += bin5(sx[k])*l; l *= 5; }
}
jbest = u;
} else {
auto vbest = _mm256_set1_ps(INFINITY);
@@ -3476,7 +3482,7 @@ template <int block_size, int group_size, int num_bits>
std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits>::finalize_clusters(int num_neighbours,
const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) {
int ncluster = clusters.size()/kGroupSize;
GGML_ASSERT(ncluster%8 == 0);
//GGML_ASSERT(ncluster%8 == 0);
std::vector<std::vector<int>> p_in_cluster(ncluster);
std::vector<int> which_cluster(num_neighbours*kNumVal);
std::vector<int> ibest(num_neighbours);
@@ -3600,15 +3606,18 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_poin
}
return result;
}
else if (ndim == 4 && ncluster == 256) {
else if (ndim == 4 && (ncluster == 256 || ncluster == 625)) {
std::memset(sump.data(), 0, sump.size()*sizeof(float));
std::memset(counts.data(), 0, counts.size()*sizeof(int));
//printf("%s: simple with group size %d\n", __func__, group_size);
//printf("%s: midpoints = %g, %g, %g, %g\n", __func__, mid[0], mid[1], mid[2], mid[3]);
for (int ip = 0; ip < npoint; ++ip) {
auto vp = points.data() + ndim*ip;
uint8_t u = 0;
for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k);
uint16_t u = 0;
if (ncluster == 255) {
for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k);
} else {
int s = 1;
for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; }
}
++counts[u];
for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
}
@@ -4244,13 +4253,13 @@ void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
namespace{
using QuantizerIQ4KT = QuantizerIQKT<64, 4, 16>;
using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>;
const QuantizerIQ4KT& iq4kt_quantizer() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unique_ptr<QuantizerIQ4KT> quantizer;
if (!quantizer) quantizer = std::make_unique<QuantizerIQ4KT>(512, 5);
if (!quantizer) quantizer = std::make_unique<QuantizerIQ4KT>(625, 6);
return *quantizer;
}
@@ -4263,9 +4272,7 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
float * dptr = (float *)vy;
block_iq4_kt * y = (block_iq4_kt *)(dptr + 1);
int best_idx[2*Q::kNg];
block_iq4_kt * y = (block_iq4_kt *)(dptr + 2);
auto& quantizer = iq4kt_quantizer();
@@ -4273,14 +4280,22 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights);
float amax_row = 0;
for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j]));
float amax_row = 0, row_av = 0;
for (int j = 0; j < n_per_row; ++j) {
row_av += x[j];
amax_row = std::max(amax_row, std::abs(x[j]));
}
row_av /= n_per_row;
dptr[1] = row_av;
if (!amax_row) {
*dptr = 0.f;
dptr[0] = 0.f;
std::memset(y, 0, nblock*sizeof(block_iq4_kt));
return;
}
int best_idx[2*Q::kNg];
float xaux[Q::kBlockSize];
float amax_scale = 0, max_scale = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
@@ -4290,38 +4305,32 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
const float * xbl = x + ibl*Q::kSuperBlockSize;
auto scales = all_scales + ibl*Q::kNblock;
auto ql = (uint16_t *)y[ibl].ql;
for (int ib = 0; ib < Q::kNblock; ++ib) {
const float * xb = xbl + Q::kBlockSize*ib;
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
float amax = 0;
for (int j = 0; j < Q::kBlockSize; ++j) {
float ax = std::abs(xb[j]);
xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
float ax = std::abs(xaux[j]);
amax = std::max(amax, ax);
}
if (!amax) {
scales[ib] = 0;
ql += Q::kNg;
continue;
}
float best = 0;
float scale_0 = std::max(92.f, 127.f*amax/amax_row);
for (int itry = -2; itry <= 2; ++itry) {
quantizer.find_best_match( amax/(8.f*itry + scale_0), xb, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
quantizer.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xaux, weight, best_idx);
if (score_p > best) {
best = score_p; scales[ib] = dp;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
}
quantizer.find_best_match(-amax/(8.f*itry + scale_0), xb, weight, best_idx);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx);
quantizer.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx);
auto [dm, score_m] = quantizer.find_best_scale(xaux, weight, best_idx);
if (score_m > best) {
best = score_m; scales[ib] = dm;
for (int j = 0; j < Q::kNg; ++j) ql[j] = best_idx[j];
}
}
ql += Q::kNg;
float abs_scale = std::abs(scales[ib]);
if (abs_scale > amax_scale) {
@@ -4333,54 +4342,62 @@ void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const f
}
float d = -max_scale/128;
float id = d ? 1/d : 0.f;
for (int ibl = 0; ibl < nblock; ++ibl) {
auto scales = all_scales + ibl*Q::kNblock;
for (int ib = 0; ib < Q::kNblock; ++ib) {
int ls = nearest_int(id*scales[ib]);
y[ibl].scales[ib] = std::min(ls, 127);
}
}
*dptr = d;
dptr[0] = d;
if (!d) return;
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
for (int iloop = 0; iloop < 1; ++iloop) {
const float id = 1/d;
float sumqx = 0, sumq2 = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
auto qs = (uint16_t *)y[ibl].ql;
// high 3 bits + scales
// each block of 32 needs 8 x 3 (high bits) + 1 x 8 (scale) = 32 bits = 1 x uint32_t
// we have 8 blocks
auto shb = y[ibl].qs; // high 3 bits + scales
auto ql = (uint8_t *)(shb + Q::kNblock);
auto qh = ql + kNumGroups;
std::memset(qh, 0, kNumGroups/2);
const float * xbl = x + ibl*Q::kSuperBlockSize;
auto scales = all_scales + ibl*Q::kNblock;
for (int ib = 0; ib < Q::kNblock; ++ib) {
const float * xb = xbl + Q::kBlockSize*ib;
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
int ls = y[ibl].scales[ib];
for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av;
int ls = nearest_int(id*scales[ib]);
ls = std::min(ls, 127);
*(int8_t *)(shb + ib) = ls;
float dl = d*ls;
quantizer.find_best_match(dl, xb, weight, best_idx);
float dnew = quantizer.find_best_scale(xb, weight, best_idx).first;
ls = std::max(-128, std::min(127, nearest_int(dnew/d)));
y[ibl].scales[ib] = ls;
quantizer.find_best_match(dl, xaux, weight, best_idx);
for (int j = 0; j < Q::kNg; ++j) {
qs[j] = best_idx[j];
auto xl = xb + Q::kGroupSize*j;
shb[ib] |= ((best_idx[j] >> 12) << (8 + 3*j));
ql[Q::kNg*ib + j] = best_idx[j] & 255;
qh[(Q::kNg*ib + j)%(kNumGroups/2)] |= ((best_idx[j] >> 8) & 0xf) << 4*((Q::kNg*ib + j)/(kNumGroups/2));
auto xl = xaux + Q::kGroupSize*j;
auto wl = weight + Q::kGroupSize*j;
auto ql = quantizer.values() + qs[j]*Q::kGroupSize;
auto ql = quantizer.values() + Q::kGroupSize*best_idx[j];
for (int k = 0; k < Q::kGroupSize; ++k) {
float q = ql[k]*ls;
sumqx += wl[k]*xl[k]*q;
sumq2 += wl[k]*q*q;
}
}
qs += Q::kNg;
//ls += 128;
//qs[2*ib+0] = uint64_t(best_idx[0]) | (uint64_t(best_idx[1]) << 15) | (uint64_t(best_idx[2]) << 30) | (uint64_t(best_idx[3]) << 45) |
// (uint64_t(ls & 0x0f) << 60);
//qs[2*ib+1] = uint64_t(best_idx[4]) | (uint64_t(best_idx[5]) << 15) | (uint64_t(best_idx[6]) << 30) | (uint64_t(best_idx[7]) << 45) |
// (uint64_t(ls & 0xf0) << 56);
}
}
if (sumq2 > 0) {
d = sumqx/sumq2;
*dptr = d;
if (!d) return;
dptr[0] = d;
if (!d) break;
} else {
break;
}
@@ -4416,20 +4433,26 @@ size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) {
using Q = QuantizerIQ4KT;
assert(k % Q::kSuperBlockSize == 0);
constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
const int nb = k / Q::kSuperBlockSize;
const float * dptr = (const float *)x;
const float d = *dptr * Q::kScale;
x = (const block_iq4_kt *)(dptr + 1);
const float d = dptr[0] * Q::kScale;
const float row_av = dptr[1];
x = (const block_iq4_kt *)(dptr + 2);
auto& deq = iq4kt_quantizer();
for (int ibl = 0; ibl < nb; ++ibl) {
const uint16_t * ql = (const uint16_t *)x[ibl].ql;
auto shb = x[ibl].qs;
auto ql = (const uint8_t *)(shb + Q::kNblock);
auto qh = ql + kNumGroups;
for (int ib = 0; ib < Q::kNblock; ++ib) {
float sl = d * x[ibl].scales[ib];
float sl = d * ((const int8_t *)(shb + ib))[0];
for (int ig = 0; ig < Q::kNg; ++ig) {
deq.set_values(ql[ig], y, sl);
int jj = ib*Q::kNg+ig;
uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12);
deq.set_values(idx, y, sl);
for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av;
y += Q::kGroupSize;
}
ql += Q::kNg;
}
}
}