Checkpoint

Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude
plus 1 bpw for the sign. It gives a visible improvement in the
PPL vs bpw plot, but that comes at the expense of much longer
quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX).

I also noticed that the 3INST generator is not actually generating a
Gaussian distribution. But going to a better generator means
readjusting all the hyper-parameters, so leaving it for later.
This commit is contained in:
Iwan Kawrakow
2024-11-19 17:31:07 +02:00
parent 2be4cffe66
commit 3a9926b932
4 changed files with 237 additions and 129 deletions

View File

@@ -349,7 +349,12 @@ float __device__ __forceinline__ trellis_next(uint32_t& val) {
const half * h = (const half *)&s;
val = ka*val + kb;
s = (val & kmask) ^ km32;
return (float)(h[0] +h[1]);
//float r = (float)(h[0] +h[1]);
//val = ka*val + kb;
//s = (val & kmask) ^ km32;
//r += (float)(h[0]+h[1]);
//return r;
return (float)(h[0]+h[1]);
}
template<typename dst_t>
@@ -383,20 +388,42 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;
const int8_t * scale_values = iq4k_values + 16;
const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
uint32_t idx1 = x[i].ql[2*ib+0] + ((x[i].qh[(2*ib+0)%32] << (8-4*((2*ib+0)/32))) & 0xf00) + 4096;
uint32_t idx2 = x[i].ql[2*ib+1] + ((x[i].qh[(2*ib+1)%32] << (8-4*((2*ib+1)/32))) & 0xf00) + 4096;
const float dl = scale * scale_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.015f;
for (int j = 0; j < 4; ++j) {
y[j+0] = dl * trellis_next(idx1);
y[j+4] = dl * trellis_next(idx2);
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
uint8_t mask = 1 << (ib/4);
for (int j = 0; j < 8; ++j) {
y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
}
}
//template<typename dst_t>
//static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
//
// int64_t ii = blockIdx.x;
// int64_t row = (QK_K * ii) / n_per_row;
// const float * dptr = (const float *)((const char *)vx + row * row_size);
// float scale = dptr[0];
// float alpha = dptr[1];
// const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 2);
// const int64_t i = ii - (row*n_per_row)/QK_K;
//
// const int64_t tid = threadIdx.x;
// const int64_t ib = tid; // 0...31
// dst_t * y = yy + ii*QK_K + 8*ib;
// const uint16_t * ql = (const uint16_t *)x[i].ql;
// uint32_t idx = ql[ib] + 4096;
// const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
// uint8_t mask = 1 << (ib/4);
// for (int j = 0; j < 8; ++j) {
// float ay = std::abs(trellis_next(idx));
// y[j] = dl * ay/(1 - alpha*ay) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
// }
//}
template<typename dst_t>
static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

View File

@@ -41,6 +41,54 @@ static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& v
#endif
}
//static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
// const half * h = (const half *)s;
// s[0] = trellis_next(val1);
// s[1] = trellis_next(val1);
// s[2] = trellis_next(val1);
// s[3] = trellis_next(val1);
//#ifdef GGML_CUDA_F16
// bdot1 = __hfma2(y[ 0], {h[0]+h[1]+h[2]+h[3], h[4]+h[5]+h[6]+h[7]}, bdot1);
//#else
// bdot1.x += y[ 0].x * (float)(h[0] + h[1] + h[2] + h[3]);
// bdot1.y += y[ 0].y * (float)(h[4] + h[5] + h[6] + h[7]);
//#endif
// s[0] = trellis_next(val2);
// s[1] = trellis_next(val2);
// s[2] = trellis_next(val2);
// s[3] = trellis_next(val2);
//#ifdef GGML_CUDA_F16
// bdot2 = __hfma2(y[64], {h[0]+h[1]+h[2]+h[3], h[4]+h[5]+h[6]+h[7]}, bdot2);
//#else
// bdot2.x += y[64].x * (float)(h[0] + h[1] + h[2] + h[3]);
// bdot2.y += y[64].y * (float)(h[4] + h[5] + h[6] + h[7]);
//#endif
//}
static __device__ __forceinline__ void trellis_accum_abs(uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2,
uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) {
const half * h = (const half *)s;
s[0] = trellis_next(val1);
s[1] = trellis_next(val1);
s[2] = trellis_next(val2);
s[3] = trellis_next(val2);
#ifdef GGML_CUDA_F16
half h00 = __habs(h[0]+h[1]), h01 = __habs(h[2]+h[3]);
half h10 = __habs(h[4]+h[5]), h11 = __habs(h[6]+h[7]);
half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01};
half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11};
//half2 h1 = __hmul2(__habs2({h[0]+h[1], h[2]+h[3]}), {signs1 & mask1 ? -1 : 1, signs2 & mask1 ? -1 : 1});
//half2 h2 = __hmul2(__habs2({h[4]+h[5], h[6]+h[7]}), {signs1 & mask2 ? -1 : 1, signs2 & mask2 ? -1 : 1});
bdot1 = __hfma2(y[ 0], h1, bdot1);
bdot2 = __hfma2(y[64], h2, bdot2);
#else
bdot1.x += y[ 0].x * fabsf((float)(h[0] + h[1])) * (signs1 & mask1 ? -1 : 1);
bdot1.y += y[ 0].y * fabsf((float)(h[2] + h[3])) * (signs2 & mask1 ? -1 : 1);
bdot2.x += y[64].x * fabsf((float)(h[4] + h[5])) * (signs1 & mask2 ? -1 : 1);
bdot2.y += y[64].y * fabsf((float)(h[6] + h[7])) * (signs2 & mask2 ? -1 : 1);
#endif
}
static __device__ __forceinline__ void trellis_accum(const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) {
#ifdef GGML_CUDA_F16
tmp = __hfma2(dl1, bdot1, tmp);
@@ -114,25 +162,23 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v
uint32_t s[4];
uint8_t mask1 = 1 << (it/4);
uint8_t mask2 = mask1 << 4;
for (int i = ix; i < num_blocks_per_row; i += 2) {
const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
const uint8_t * ql = x[i].ql;
const uint8_t * qh = x[i].qh;
const dfloat scale1 = iq4k_values[(x[i].scales[it/4] & 0xf)+16];
const dfloat scale2 = iq4k_values[(x[i].scales[it/4] >> 4)+16];
const uint16_t * ql = (const uint16_t *)x[i].ql;
const uint8_t * qh = x[i].qh;
const dfloat scale1 = (x[i].scales[it/4] & 0xf);
const dfloat scale2 = (x[i].scales[it/4] >> 4);
const dfloat2 dl1 = {scale1, scale1};
const dfloat2 dl2 = {scale2, scale2};
dfloat2 bdot1 = {0, 0};
dfloat2 bdot2 = {0, 0};
uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + 4096;
uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + 4096;
for (int k = 0; k < 2; ++k) {
trellis_accum(val1, val2, s, y+k, bdot1, bdot2);
}
val1 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + 4096;
val2 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + 4096;
for (int k = 2; k < 4; ++k) {
trellis_accum(val1, val2, s, y+k, bdot1, bdot2);
uint32_t val1 = ql[it+ 0] + 4096;
uint32_t val2 = ql[it+16] + 4096;
for (int k = 0; k < 4; ++k) {
trellis_accum_abs(qh[(8*it+2*k+0)%32], qh[(8*it+2*k+1)%32], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2);
}
trellis_accum(dl1, dl2, bdot1, bdot2, tmp);
}

View File

@@ -3151,7 +3151,7 @@ __m256 hsum_float_4x8(__m256 * accm) {
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
#endif
template <int block_size, int group_size, int num_bits>
template <int block_size, int group_size, int num_bits, bool is_abs = false>
class QuantizerIQKT {
static_assert(group_size == 8 || group_size == 4);
static_assert(block_size >= 8 && block_size%8 == 0);
@@ -3182,12 +3182,36 @@ public:
x = ka*x + kb;
uint32_t s = (x & kmask) ^ km32;
float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
result[k] = scale*val;
if constexpr (is_abs) result[k] = scale*std::abs(val);
else result[k] = scale*val;
}
//for (int k = 0; k < kGroupSize; ++k) {
// x = ka*x + kb;
// uint32_t s = (x & kmask) ^ km32;
// float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
// x = ka*x + kb;
// s = (x & kmask) ^ km32;
// val += GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16);
// if constexpr (is_abs) result[k] = scale*std::abs(0.5f*val);
// else result[k] = 0.5f*scale*val;
//}
}
static inline int bin4(float x) { return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3; }
static inline int bin5(float x) { return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4; }
static inline int bin4(float x) {
if constexpr (is_abs) {
return x < 16.f ? 0 : x < 32.f ? 1 : x < 64.f ? 2 : 3;
} else {
return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3;
}
}
static inline int bin5(float x) {
if constexpr (is_abs) {
return x < 11.2f ? 0 : x < 24.f ? 1 : x < 39.f ? 2 : x < 58.f ? 3 : 4;
} else {
return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4;
}
}
inline int bin3(int idim, float x) const { return x < m_mid[2*idim+0] ? 0 : x < m_mid[2*idim+1] ? 1 : 2; }
static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) {
for (int ibl = 0; ibl < nblock; ++ibl) {
@@ -3215,11 +3239,11 @@ private:
std::vector<float> m_clusters;
std::vector<std::vector<int>> m_in_cluster;
std::vector<std::vector<float>> m_c_values;
float m_mid[kGroupSize];
float m_mid[4*kGroupSize];
};
template <int block_size, int group_size, int num_bits>
QuantizerIQKT<block_size, group_size, num_bits>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
template <int block_size, int group_size, int num_bits, bool is_abs>
QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) {
m_values.resize(kNumVal*kGroupSize);
float * data = m_values.data();
for (int i = 0; i < kNumVal; ++i) {
@@ -3234,8 +3258,8 @@ QuantizerIQKT<block_size, group_size, num_bits>::QuantizerIQKT(int num_clusters,
m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values);
}
template <int block_size, int group_size, int num_bits>
std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits>::find_best_scale(
template <int block_size, int group_size, int num_bits, bool is_abs>
std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale(
const float * xb, const float * weight, const int * best_idx) const {
float sumqx = 0, sumq2 = 0;
#ifdef __AVX2__
@@ -3267,8 +3291,8 @@ std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits>::find_be
return sumq2 > 0 ? std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f);
}
template <int block_size, int group_size, int num_bits>
float QuantizerIQKT<block_size, group_size, num_bits>::find_best_inverse_scale(
template <int block_size, int group_size, int num_bits, bool is_abs>
float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale(
const float * xb, const float * weight, const int * best_idx) const {
float sumqx = 0, sumx2 = 0;
#ifdef __AVX2__
@@ -3300,8 +3324,8 @@ float QuantizerIQKT<block_size, group_size, num_bits>::find_best_inverse_scale(
return sumx2 > 0 ? sumqx/sumx2 : 0.f;
}
template <int block_size, int group_size, int num_bits>
void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
template <int block_size, int group_size, int num_bits, bool is_abs>
void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const {
if (!d) {
std::memset(best_idx, 0, kNg*sizeof(int));
return;
@@ -3322,10 +3346,15 @@ void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, c
auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl));
auto vw = _mm256_loadu_ps(wl);
int jbest = -1;
if (kGroupSize == 8 && ncluster == 256) {
if (kGroupSize == 8 && (ncluster == 256 || ncluster == 6561)) {
_mm256_store_ps(sx, vx);
uint8_t u = 0;
for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j);
uint16_t u = 0;
if (ncluster == 256) {
for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j);
} else {
int s = 1;
for (int j = 0; j < 8; ++j) { u += s*bin3(j, sx[j]); s *= 3; }
}
jbest = u;
} else {
auto vbest = _mm256_set1_ps(INFINITY);
@@ -3352,14 +3381,16 @@ void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, c
}
}
auto& points = m_in_cluster[jbest];
auto& values = m_c_values[jbest];
GGML_ASSERT(!points.empty() && points.size()%8 == 0);
auto& values = points.empty() ? m_values : m_c_values[jbest];
int npoint = values.size()/kGroupSize;
//if (points.empty() || points.size()%8 != 0) printf("Oops: %d points in cluster %d\n", int(points.size()), jbest);
GGML_ASSERT(npoint > 0 && npoint%8 == 0);
int jbest_cluster = jbest;
auto vbest = _mm256_set1_ps(INFINITY);
auto best_index = _mm256_set1_epi32(-1);
auto best = INFINITY; jbest = -1;
auto idx = add_idx;
for (int j = 0; j < int(points.size()); j += 8) {
for (int j = 0; j < npoint; j += 8) {
for (int i = 0; i < 8; ++i) {
auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i));
auto vdiff = _mm256_sub_ps(vq, vx);
@@ -3381,7 +3412,7 @@ void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, c
fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size()));
GGML_ASSERT(false);
}
best_idx[l] = points[jbest];
best_idx[l] = points.empty() ? jbest : points[jbest];
}
} else {
__m256 sqx[4];
@@ -3478,8 +3509,8 @@ void QuantizerIQKT<block_size, group_size, num_bits>::find_best_match(float d, c
#endif
}
template <int block_size, int group_size, int num_bits>
std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits>::finalize_clusters(int num_neighbours,
template <int block_size, int group_size, int num_bits, bool is_abs>
std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours,
const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) {
int ncluster = clusters.size()/kGroupSize;
//GGML_ASSERT(ncluster%8 == 0);
@@ -3566,8 +3597,8 @@ std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits>::f
return p_in_cluster;
}
template <int block_size, int group_size, int num_bits>
std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
template <int block_size, int group_size, int num_bits, bool is_abs>
std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) {
constexpr int ndim = kGroupSize;
GGML_ASSERT(points.size() % ndim == 0);
int npoint = points.size() / ndim;
@@ -3583,17 +3614,40 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_poin
}
}
if (kVerbose) printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size());
for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second);
if constexpr (is_abs) {
std::vector<int> P(npoint);
for (int idim = 0; idim < ndim; ++idim) {
for (int ip = 0; ip < npoint; ++ip) P[ip] = points[ip*ndim+idim];
std::sort(P.begin(), P.end());
if (ndim == 8 && ncluster == 6561) {
mid[2*idim + 0] = P[npoint/3];
mid[2*idim + 1] = P[2*npoint/3];
} else {
mid[idim] = npoint%2 == 0 ? 0.5f*(P[npoint/2] + P[npoint/2-1]) : P[npoint/2];
if (kVerbose) printf("%s: mid[%d] = %g\n", __func__, idim, mid[idim]);
}
}
} else {
for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second);
}
std::vector<float> sump(ncluster*ndim);
std::vector<int> counts(ncluster);
std::vector<float> result(ncluster*ndim);
if (ndim == 8 && ncluster == 256) {
if (ndim == 8 && (ncluster == 256 || ncluster == 6561)) {
std::memset(sump.data(), 0, sump.size()*sizeof(float));
std::memset(counts.data(), 0, counts.size()*sizeof(int));
for (int ip = 0; ip < npoint; ++ip) {
auto vp = points.data() + ndim*ip;
uint8_t u = 0;
for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k);
uint16_t u = 0;
if (ncluster == 256) {
for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k);
} else {
int s = 1;
for (int k = 0; k < ndim; ++k) {
int bin = vp[k] < mid[2*k+0] ? 0 : vp[k] < mid[2*k+1] ? 1 : 2;
u += s*bin; s *= 3;
}
}
++counts[u];
for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
}
@@ -3695,6 +3749,11 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_poin
best = dist2; ibest = ic;
}
}
if (ibest < 0) {
printf("Oops(iteration %d) - failed to find cluster for point", iter);
for (int k = 0; k < ndim; ++k) printf(" %g", vp[k]);
printf("\nHave %d clusters\n", ncluster);
}
GGML_ASSERT(ibest >= 0);
F += best;
if (which_cluster[ip] != ibest) ++nchanged;
@@ -3714,6 +3773,11 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_poin
if (iter > 1 && Flast/F - 1 < 1e-6) break;
Flast = F;
}
int nzero = 0;
for (int ic = 0; ic < ncluster; ++ic) {
if (!counts[ic]) ++nzero;
}
if (nzero > 0) printf("%s: there are %d empty clusters\n", __func__, nzero);
return result;
}
@@ -3975,12 +4039,12 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
namespace {
using QuantizerIQ3KT = QuantizerIQKT<32, 4, 12>;
using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>;
const QuantizerIQ3KT& iq3kt_quantizer() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unique_ptr<QuantizerIQ3KT> quantizer;
if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 16);
if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 8);
return *quantizer;
}
@@ -3988,7 +4052,7 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
float * all_weights, float * qtmp) {
constexpr float kSigmaScale = 2.0f;
constexpr float kStep = 4.0f;
constexpr float kStep = 8.0f;
using Q = QuantizerIQ3KT;
@@ -4018,6 +4082,8 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
float amax_scale = 0, max_scale = 0;
float xaux[Q::kBlockSize];
for (int ibl = 0; ibl < nblock; ++ibl) {
memset(&y[ibl], 0, sizeof(block_iq3_kt));
@@ -4031,27 +4097,24 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
float amax = 0;
for (int j = 0; j < Q::kBlockSize; ++j) {
float ax = std::abs(xb[j]);
xaux[j] = ax;
amax = std::max(amax, ax);
}
scales[ib] = 0;
if (!amax) continue;
float scale_0 = std::max(80.f, 123.f*amax/amax_row);
//float scale_0 = 80.f;
//quantizer.find_best_match(amax/96.f, xaux, weight, best_idx+Q::kNg);
//scales[ib] = quantizer.find_best_scale(xaux, weight, best_idx+Q::kNg).first;
float scale_0 = std::max(84.f, 123.f*amax/amax_row);
//float scale_0 = std::max(64.f, 123.f*amax/amax_row);
float best = 0;
for (int itry = -5; itry <= 5; ++itry) {
quantizer.find_best_match(amax/(scale_0 + kStep*itry), xb, weight, best_idx);
auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_p > best) {
best = score_p;
scales[ib] = dp;
std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int));
}
quantizer.find_best_match(-amax/(scale_0 + kStep*itry), xb, weight, best_idx);
auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx);
if (score_m > best) {
best = score_m;
scales[ib] = dm;
for (int itry = -3; itry <= 3; ++itry) {
quantizer.find_best_match(amax/(scale_0 + kStep*itry), xaux, weight, best_idx);
auto [d, score] = quantizer.find_best_scale(xaux, weight, best_idx);
if (score > best) {
best = score;
scales[ib] = d;
std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int));
}
}
@@ -4071,12 +4134,11 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
}
auto scale_values = iq4k_values + 16;
float d = max_scale/scale_values[0];
GGML_ASSERT(max_scale >= 0);
float d = max_scale/15;
float best = 0;
for (int itry = -9; itry <= 9; ++itry) {
float id = (itry + scale_values[0])/max_scale;
float id = (itry*0.2f + 15)/max_scale;
float sumqx = 0, sumq2 = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * xb = x + ibl*Q::kSuperBlockSize;
@@ -4084,11 +4146,12 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
const float * wb = all_weights + ibl*Q::kSuperBlockSize;
auto scales = all_scales + ibl*Q::kNblock;
for (int ib = 0; ib < Q::kNblock; ++ib) {
int ls = best_index_iq4nl(scale_values, id*scales[ib]);
float dl = scale_values[ls];
int ls = nearest_int(id*scales[ib]);
ls = std::max(0, std::min(15, ls));
float dl = ls;
for (int j = 0; j < Q::kBlockSize; ++j) {
float q = dl*qb[j];
sumqx += wb[j]*xb[j]*q;
sumqx += wb[j]*std::abs(xb[j])*q;
sumq2 += wb[j]*q*q;
}
xb += Q::kBlockSize;
@@ -4101,20 +4164,16 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
}
}
//float d = -max_scale/128;
float id = d ? 1/d : 0.f;
for (int ibl = 0; ibl < nblock; ++ibl) {
auto scales = all_scales + ibl*Q::kNblock;
for (int ib = 0; ib < Q::kNblock/2; ++ib) {
int ls1 = best_index_iq4nl(scale_values, id*scales[ib]);
int ls2 = best_index_iq4nl(scale_values, id*scales[ib + Q::kNblock/2]);
int ls1 = nearest_int(id*scales[ib]);
int ls2 = nearest_int(id*scales[ib + Q::kNblock/2]);
ls1 = std::max(0, std::min(15, ls1));
ls2 = std::max(0, std::min(15, ls2));
y[ibl].scales[ib] = ls1 | (ls2 << 4);
}
//int8_t * sv = (int8_t *)y[ibl].scales;
//for (int ib = 0; ib < Q::kNblock; ++ib) {
// int ls = nearest_int(id*scales[ib]);
// sv[ib] = std::min(ls, 127);
//}
}
*dptr = d;
@@ -4124,23 +4183,25 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
float sumqx = 0, sumq2 = 0;
for (int ibl = 0; ibl < nblock; ++ibl) {
uint16_t * ql = (uint16_t *)y[ibl].ql;
std::memset(y[ibl].qh, 0, kNumGroups/2);
const float * xbl = x + ibl*Q::kSuperBlockSize;
//int8_t * sv = (int8_t *)y[ibl].scales;
for (int ib = 0; ib < Q::kNblock; ++ib) {
const float * xb = xbl + Q::kBlockSize*ib;
const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
//int ls = sv[ib];
int ls = scale_values[((y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf)];
for (int j = 0; j < Q::kBlockSize; ++j) {
xaux[j] = std::abs(xb[j]);
if (xb[j] < 0) y[ibl].qh[j] |= (1 << ib);
}
int ls = (y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf;
float dl = d*ls;
quantizer.find_best_match(dl, xb, weight, best_idx);
quantizer.find_best_match(dl, xaux, weight, best_idx);
for (int j = 0; j < Q::kNg; ++j) {
int jj = ib*Q::kNg + j;
y[ibl].ql[jj] = best_idx[j] & 255;
y[ibl].qh[jj%(kNumGroups/2)] |= ((best_idx[j] >> 8) << 4*(jj/(kNumGroups/2)));
auto xl = xb + Q::kGroupSize*j;
ql[ib*Q::kNg+j] = best_idx[j];
auto xl = xaux + Q::kGroupSize*j;
auto wl = weight + Q::kGroupSize*j;
auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize;
for (int k = 0; k < Q::kGroupSize; ++k) {
@@ -4159,32 +4220,6 @@ void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const f
break;
}
}
//float aux[Q::kGroupSize];
//float sumd = 0, sumw = 0, mse = 0;
//for (int ibl = 0; ibl < nblock; ++ibl) {
// const float * xbl = x + ibl*Q::kSuperBlockSize;
// for (int ib = 0; ib < Q::kNblock; ++ib) {
// const float * xb = xbl + Q::kBlockSize*ib;
// const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize;
// int ls = scale_values[((y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf)];
// float dl = d*ls*Q::kScale;
// for (int j = 0; j < Q::kNg; ++j) {
// int jj = ib*Q::kNg + j;
// int idx = y[ibl].ql[jj] + ((y[ibl].qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00);
// quantizer.set_values(idx, aux, dl);
// auto xl = xb + Q::kGroupSize*j;
// auto wl = weight + Q::kGroupSize*j;
// for (int k = 0; k < Q::kGroupSize; ++k) {
// float diff = xl[k] - aux[k];
// sumw += wl[k];
// sumd += wl[k]*diff;
// mse += diff*diff;
// }
// }
// }
//}
//printf("rmse = %g, delta = %g, %g\n", sqrt(mse/n_per_row), sumd/sumw, sumd/sumw/amax_row);
}
}
@@ -4220,26 +4255,26 @@ void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
const int nb = k / Q::kSuperBlockSize;
const float * dptr = (const float *)x;
const float d = *dptr * Q::kScale;
const int8_t * scale_values = iq4k_values + 16;
x = (const block_iq3_kt *)(dptr + 1);
auto& deq = iq3kt_quantizer();
for (int ibl = 0; ibl < nb; ++ibl) {
auto yl = y + ibl*Q::kSuperBlockSize;
auto yh = yl + Q::kSuperBlockSize/2;
auto qll = x[ibl].ql;
auto qll = (const uint16_t *)x[ibl].ql;
auto qlh = qll + kNumGroups/2;
//const int8_t * sv = (const int8_t *)x[ibl].scales;
int jj = 0;
for (int ib = 0; ib < Q::kNblock/2; ++ib) {
//float sl = d * sv[ib];
//float sh = d * sv[ib+Q::kNblock/2];
float sl = d * scale_values[(x[ibl].scales[ib] & 0xf)];
float sh = d * scale_values[(x[ibl].scales[ib] >> 4)];
float sl = d * (x[ibl].scales[ib] & 0xf);
float sh = d * (x[ibl].scales[ib] >> 4);
uint8_t l_mask = 1 << ib;
uint8_t h_mask = l_mask << (Q::kNblock/2);
for (int ig = 0; ig < Q::kNg; ++ig) {
uint16_t ul = qll[jj] | ((x[ibl].qh[jj] << 8) & 0xf00);
uint16_t uh = qlh[jj] | ((x[ibl].qh[jj] << 4) & 0xf00);
deq.set_values(ul, yl, sl);
deq.set_values(uh, yh, sh);
deq.set_values(qll[jj], yl, sl);
deq.set_values(qlh[jj], yh, sh);
for (int j = 0; j < Q::kGroupSize; ++j) {
if (x[ibl].qh[ig*Q::kGroupSize+j] & l_mask) yl[j] = -yl[j];
if (x[ibl].qh[ig*Q::kGroupSize+j] & h_mask) yh[j] = -yh[j];
}
yl += Q::kGroupSize;
yh += Q::kGroupSize;
++jj;

View File

@@ -15976,7 +15976,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
} else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
//else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT ) new_type = GGML_TYPE_IQ3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT && qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_IQ4_K;