mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-21 15:09:40 +00:00
iq4_kss: WIP
This commit is contained in:
@@ -256,6 +256,8 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
|
||||
float mse0 = 0, mse = 0;
|
||||
auto compute = [&mutex, &counter, &mse0, &mse, values, row_size, nblock, nrows, n_per_row, chunk] () {
|
||||
std::vector<char> Q(row_size);
|
||||
float diff[4];
|
||||
float xv[4];
|
||||
float lmse0 = 0, lmse = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
@@ -282,25 +284,41 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
|
||||
for (int j = 0; j < 16; j += 2) {
|
||||
uint16_t v0 = *(const uint16_t *)(qs + j);
|
||||
int non = popcount(v0);
|
||||
float diff1 = xb[j+ 0] - dl*values[qs[j+0] & 0xf];
|
||||
float diff2 = xb[j+16] - dl*values[qs[j+0] >> 4];
|
||||
float diff3 = xb[j+ 1] - dl*values[qs[j+1] & 0xf];
|
||||
float diff4 = xb[j+17] - dl*values[qs[j+1] >> 4];
|
||||
lmse0 += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
|
||||
xv[0] = xb[j+ 0]; xv[1] = xb[j+16]; xv[2] = xb[j+ 1]; xv[3] = xb[j+17];
|
||||
diff[0] = xv[0] - dl*values[qs[j+0] & 0xf];
|
||||
diff[1] = xv[1] - dl*values[qs[j+0] >> 4];
|
||||
diff[2] = xv[2] - dl*values[qs[j+1] & 0xf];
|
||||
diff[3] = xv[3] - dl*values[qs[j+1] >> 4];
|
||||
float diff4 = diff[0]*diff[0] + diff[1]*diff[1] + diff[2]*diff[2] + diff[3]*diff[3];
|
||||
lmse0 += diff4;
|
||||
if (non%2 == 0) {
|
||||
lmse += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
|
||||
lmse += diff4;
|
||||
} else {
|
||||
float best = std::numeric_limits<float>::max();
|
||||
for (int k = 0; k < 16; k += 4) {
|
||||
uint16_t v = v0 ^ (1 << k);
|
||||
uint8_t v1 = v;
|
||||
uint8_t v2 = v >> 8;
|
||||
diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
|
||||
diff2 = xb[j+16] - dl*values[v1 >> 4];
|
||||
diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
|
||||
diff4 = xb[j+17] - dl*values[v2 >> 4];
|
||||
float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
|
||||
if (score < best) best = score;
|
||||
//for (int k = 0; k < 16; k += 4) {
|
||||
// uint16_t v = v0 ^ (1 << k);
|
||||
// uint8_t v1 = v;
|
||||
// uint8_t v2 = v >> 8;
|
||||
// diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
|
||||
// diff2 = xb[j+16] - dl*values[v1 >> 4];
|
||||
// diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
|
||||
// diff4 = xb[j+17] - dl*values[v2 >> 4];
|
||||
// float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
|
||||
// if (score < best) best = score;
|
||||
//}
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
uint16_t v = (v0 >> 4*k) & 0xf;
|
||||
auto pc = popcount(v);
|
||||
if (v > 0 && popcount(v-1u) != pc) {
|
||||
float this_diff = xv[k] - dl*values[v-1u];
|
||||
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
|
||||
if (score < best) best = score;
|
||||
}
|
||||
if (v < 15 && popcount(v + 1u) != pc) {
|
||||
float this_diff = xv[k] - dl*values[v+1u];
|
||||
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
|
||||
if (score < best) best = score;
|
||||
}
|
||||
}
|
||||
lmse += best;
|
||||
}
|
||||
|
||||
@@ -405,6 +405,7 @@ extern "C" {
|
||||
GGML_TYPE_IQ1_TN = 143,
|
||||
GGML_TYPE_IQ4_KS = 144,
|
||||
GGML_TYPE_IQ2_KS = 145,
|
||||
GGML_TYPE_IQ4_KSS = 146,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
@@ -462,6 +463,7 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
|
||||
@@ -447,6 +447,11 @@ typedef struct {
|
||||
} block_iq4_ks;
|
||||
static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");
|
||||
|
||||
// iq4_kss: 4-bit non-linear quantization, QK_K values per super-block, with
// no per-block metadata fields. Each uint32_t packs two scrambled 15-bit
// groups (bits 0-14 and 15-29), each group decoding to four 4-bit quants via
// x ^ (x << 1) (the 16th bit is recovered from parity), plus 2 bits (30-31)
// of the per-32-value scale byte. The row-level float scale precedes the
// blocks in the row layout.
typedef struct {
    uint32_t qs[QK_K/8];
} block_iq4_kss;
static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d;
|
||||
uint16_t extra;
|
||||
|
||||
@@ -1100,6 +1100,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.nrows = 1,
|
||||
.row_meta_size = 4,
|
||||
},
|
||||
[GGML_TYPE_IQ4_KSS] = {
|
||||
.type_name = "iq4_kss",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_iq4_kss),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq4_kss,
|
||||
.from_float = quantize_row_iq4_kss,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kss_ref,
|
||||
.vec_dot = vec_dot_iq4_kss_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
.row_meta_size = 4,
|
||||
},
|
||||
[GGML_TYPE_Q8_K] = {
|
||||
.type_name = "q8_K",
|
||||
.blck_size = QK_K,
|
||||
@@ -3918,6 +3931,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break;
|
||||
@@ -10419,6 +10433,7 @@ static void ggml_compute_forward_add(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -10809,6 +10824,7 @@ static void ggml_compute_forward_add1(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -10949,6 +10965,7 @@ static void ggml_compute_forward_acc(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -14135,6 +14152,7 @@ static void ggml_compute_forward_out_prod(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -14515,6 +14533,7 @@ static void ggml_compute_forward_set(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -14789,6 +14808,7 @@ static void ggml_compute_forward_get_rows(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -15390,6 +15410,7 @@ static void ggml_compute_forward_clamp(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
@@ -22208,6 +22229,7 @@ size_t ggml_quantize_chunk(
|
||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
|
||||
@@ -20,6 +20,25 @@
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
|
||||
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#include <intrin.h>
#include <ammintrin.h>
#include <nmmintrin.h>
#include <immintrin.h>
#include <stdlib.h>
// Portable popcount overloads: MSVC has no __builtin_popcount, so map to the
// <intrin.h> intrinsics (not constexpr on MSVC).
// NOTE(review): __popcnt/_mm_popcnt_u64 and the *mmintrin headers are
// x86/x86-64 only — an ARM64 MSVC build would need a different path; confirm
// the supported target matrix.
inline int popcount(uint8_t x) { return __popcnt(x); }
inline int popcount(uint16_t x) { return __popcnt(x); }
inline int popcount(uint32_t x) { return __popcnt(x); }
inline int popcount(uint64_t x) { return _mm_popcnt_u64(x); }
#else
// GCC/Clang: builtins compile to a single instruction where available.
constexpr int popcount(uint8_t x) { return __builtin_popcount(x); }
constexpr int popcount(uint16_t x) { return __builtin_popcount(x); }
constexpr int popcount(uint32_t x) { return __builtin_popcount(x); }
constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); }
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -2811,3 +2830,233 @@ void vec_dot_iq4_ks_q8_k(int n, float * s, size_t bs, const void * vx, size_t b
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns a pointer to a lazily built table of 1 << 15 entries mapping a
// 15-bit index i to the unique 15-bit j satisfying
//     (j ^ (j << 1)) & 0xffff == i | (parity(i) << 15)
// i.e. the inverse of the x -> x ^ (x << 1) scramble used when packing
// iq4_kss groups (bit 15 of the target is the parity bit, so the target
// always has even popcount). Thread-safe; built once on first use.
const uint16_t * scramble_table() {
    static std::mutex mutex;
    static std::vector<uint16_t> table;
    std::lock_guard<std::mutex> lock(mutex);
    if (table.empty()) {
        table.resize(1 << 15);
        for (int i = 0; i < int(table.size()); ++i) {
            // Invert val = j ^ (j << 1) directly: val_k = j_k ^ j_{k-1}
            // implies j_k = val_k ^ j_{k-1} (a prefix XOR), with j_15 = 0.
            // Bit 15 of the target (the parity of i) is then automatically
            // satisfied, since j_14 is the XOR of bits 0..14 of i.
            // This is O(15) per entry instead of the previous O(2^15)
            // linear search per entry (~10^9 iterations for the full table).
            uint16_t j = 0;
            int prev = 0;
            for (int k = 0; k < 15; ++k) {
                int cur = ((i >> k) & 1) ^ prev;
                j |= cur << k;
                prev = cur;
            }
            table[i] = j;
        }
    }
    return table.data();
}
|
||||
uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const float * w, float dl) {
|
||||
if (popcount(v)%2 == 0) return v;
|
||||
float best_score = std::numeric_limits<float>::max();
|
||||
uint8_t q4[4];
|
||||
int jbest = -1;
|
||||
uint8_t bestq = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint8_t q = (v >> 4*j) & 0xf;
|
||||
q4[j] = q;
|
||||
auto pc = popcount(q);
|
||||
float diff0 = dl*iq4k_values[q] - x[j];
|
||||
if (q > 0) {
|
||||
uint8_t qm = q - 1u;
|
||||
int pcm = popcount(qm);
|
||||
if (pcm == pc-1 || pcm == pc+1) {
|
||||
float diff1 = dl*values[qm] - x[j];
|
||||
float score = w[j]*(diff1*diff1 - diff0*diff0);
|
||||
if (score < best_score) {
|
||||
best_score = score; jbest = j; bestq = qm;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (q < 15) {
|
||||
uint8_t qp = q + 1u;
|
||||
int pcp = popcount(qp);
|
||||
if (pcp == pc-1 || pcp == pc+1) {
|
||||
float diff1 = dl*values[qp] - x[j];
|
||||
float score = w[j]*(diff1*diff1 - diff0*diff0);
|
||||
if (score < best_score) {
|
||||
best_score = score; jbest = j; bestq = qp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(jbest >= 0);
|
||||
q4[jbest] = bestq;
|
||||
return (q4[0] | (q4[1] << 4) | (q4[2] << 8) | (q4[3] << 12));
|
||||
}
|
||||
void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
|
||||
const float * quant_weights, float * weight, float * all_scales) {
|
||||
constexpr int kBlockSize = 32;
|
||||
const float * dptr_ks = (const float *)cx;
|
||||
const float d_ks = *dptr_ks;
|
||||
const block_iq4_ks * iq4ks = (const block_iq4_ks *)(dptr_ks + 1);
|
||||
float * dptr = (float *)cy;
|
||||
*dptr = d_ks;
|
||||
block_iq4_kss * y = (block_iq4_kss *)(dptr + 1);
|
||||
int nblock = n_per_row/QK_K;
|
||||
float max_abs_scale = 0;
|
||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
auto scales = all_scales + ibl*(QK_K/kBlockSize);
|
||||
const float * xbl = x + ibl*QK_K;
|
||||
float sigma2 = 0;
|
||||
for (int j = 0; j < QK_K; ++j) sigma2 += xbl[j]*xbl[j];
|
||||
sigma2 *= 2.f/QK_K;
|
||||
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
|
||||
const float * xb = xbl + ib*kBlockSize;
|
||||
if (quant_weights) {
|
||||
const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||
} else {
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = xb[j]*xb[j];
|
||||
}
|
||||
const int8_t * values = iq4k_values + ((iq4ks[ibl].scales[ib] & 1) << 4);
|
||||
float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127);
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int k = 0; k < kBlockSize/8; ++k) {
|
||||
uint16_t vl = 0, vh = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint8_t ql = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] & 0xf;
|
||||
uint8_t qh = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] >> 4;
|
||||
vl |= ql << 4*j;
|
||||
vh |= qh << 4*j;
|
||||
}
|
||||
auto vlp = prune_iq4ks(vl, values, xb + 4*k, weight + 4*k, dl);
|
||||
auto vhp = prune_iq4ks(vh, values, xb + 4*k + kBlockSize/2, weight + 4*k + kBlockSize/2, dl);
|
||||
auto vlp_s = table[vlp & 0x7fff];
|
||||
auto vhp_s = table[vhp & 0x7fff];
|
||||
y[ibl].qs[(kBlockSize/8)*ib + k] = vlp_s | (vhp_s << 15) | (((iq4ks[ibl].scales[ib] >> 2*k) & 3) << 30);
|
||||
if ((vlp_s ^ (vlp_s << 1)) != vlp) {
|
||||
printf("Oops(l): vl = %u (%d) vlp = %u (%d) vlp_s = %u check = %u\n", vl, popcount(vl), vlp, popcount(vlp), vlp_s, vlp_s ^ (vlp_s << 1));
|
||||
exit(1);
|
||||
}
|
||||
if ((vhp_s ^ (vhp_s << 1)) != vhp) {
|
||||
printf("Oops(h): vh = %u (%d) vhp = %u (%d) vhp_s = %u check = %u\n", vh, popcount(vh), vhp, popcount(vhp), vhp_s, vhp_s ^ (vhp_s << 1));
|
||||
exit(1);
|
||||
}
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float ql = values[(vlp >> 4*j) & 0xf];
|
||||
float qh = values[(vhp >> 4*j) & 0xf];
|
||||
sumqx += weight[4*k+j]*ql*xb[4*k+j] + weight[4*k+j+16]*qh*xb[4*k+j+16];
|
||||
sumq2 += weight[4*k+j]*ql*ql + weight[4*k+j+16]*qh*qh;
|
||||
}
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
// uint8_t ql = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] & 0xf;
|
||||
// uint8_t qh = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] >> 4;
|
||||
// uint8_t qlp = (vlp >> 4*j) & 0xf;
|
||||
// uint8_t qhp = (vhp >> 4*j) & 0xf;
|
||||
// float diffl = dl*values[ql] - xb[4*k+j];
|
||||
// float diffh = dl*values[qh] - xb[4*k+j+16];
|
||||
// float difflp = dl*values[qlp] - xb[4*k+j];
|
||||
// float diffhp = dl*values[qhp] - xb[4*k+j+16];
|
||||
// printf("%d %d %d: %u %u %u %u %g %g %g %g\n", ib, k, j, ql, qh, qlp, qhp, diffl, diffh, difflp, diffhp);
|
||||
//}
|
||||
}
|
||||
scales[ib] = sumq2 > 0 ? sumqx/sumq2 : dl;
|
||||
max_abs_scale = std::max(max_abs_scale, scales[ib]);
|
||||
}
|
||||
}
|
||||
if (!max_abs_scale) return;
|
||||
float d = max_abs_scale/127;
|
||||
*dptr = d;
|
||||
float id = 1/d;
|
||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
auto scales = all_scales + ibl*(QK_K/kBlockSize);
|
||||
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
|
||||
int l = nearest_int(0.5f*(id*scales[ib]+127.f));
|
||||
l = std::max(0, std::min(127, l)) << 1;
|
||||
l |= (iq4ks[ibl].scales[ib] & 1);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
y[ibl].qs[4*ib+k] &= 0x3fffffff;
|
||||
y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t quantize_iq4_kss(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
|
||||
constexpr int kBlockSize = 32; //128;
|
||||
GGML_ASSERT(n_per_row%QK_K == 0);
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
|
||||
auto row_size_ks = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row);
|
||||
std::vector<char> work(row_size_ks);
|
||||
std::vector<float> all_scales(n_per_row/kBlockSize);
|
||||
float weight[kBlockSize];
|
||||
auto qrow = (char *)dst;
|
||||
auto table = scramble_table();
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
quantize_row_iq4_k_impl_bs128(QK_K, kBlockSize, n_per_row, src, work.data(), all_scales.data(), weight, iq4k_values, imatrix, 7);
|
||||
prune_iq4ks_to_iq4kss(n_per_row, table, work.data(), src, qrow, imatrix, weight, all_scales.data());
|
||||
src += n_per_row;
|
||||
qrow += row_size;
|
||||
}
|
||||
return nrows * row_size;
|
||||
}
|
||||
|
||||
// Reference (imatrix-free) single-row iq4_kss quantization; thin wrapper
// over quantize_iq4_kss. k is the number of values in the row.
void quantize_row_iq4_kss_ref(const float * x, block_iq4_kss * y, int64_t k) {
    quantize_iq4_kss(x, y, 1, k, nullptr);
}
|
||||
|
||||
// Type-erased single-row iq4_kss quantization (ggml from_float entry point);
// thin wrapper over quantize_iq4_kss without importance weights.
void quantize_row_iq4_kss(const float * x, void * y, int64_t k) {
    quantize_iq4_kss(x, (block_iq4_kss *)y, 1, k, nullptr);
}
|
||||
|
||||
void dequantize_row_iq4_kss(const block_iq4_kss * x, float * y, int64_t k) {
|
||||
const float * dptr = (const float *)x;
|
||||
const float d = *dptr;
|
||||
x = (const block_iq4_kss *)(dptr + 1);
|
||||
//uint32_t aux32[4];
|
||||
//const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
for (int ibl = 0; ibl < k/QK_K; ++ibl) {
|
||||
auto qs = x[ibl].qs;
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
uint8_t ls = ((qs[0] >> 30) | ((qs[1] >> 28) & 0x0c) | ((qs[2] >> 26) & 0x30) | ((qs[3] >> 24) & 0xc0));
|
||||
const int8_t * values = iq4k_values + ((ls & 1) << 4);
|
||||
const float dl = d * ((ls & 254) - 127);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
uint16_t vl = qs[k] & 0x7fff;
|
||||
vl ^= (vl << 1);
|
||||
uint16_t vh = (qs[k] >> 15) & 0x7fff;
|
||||
vh ^= (vh << 1);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[4*k + j + 0] = dl*values[(vl >> 4*j) & 0xf];
|
||||
y[4*k + j + 16] = dl*values[(vh >> 4*j) & 0xf];
|
||||
}
|
||||
}
|
||||
//int16_t ls = 0;
|
||||
//for (int k = 0; k < 4; ++k) {
|
||||
// aux32[k] = (qs[k] & 0x00007fff) | ((qs[k] << 1) & 0x7fff0000);
|
||||
// aux32[k] ^= (aux32[k] << 1);
|
||||
// ls |= (qs[k] >> 30) << 2*k;
|
||||
//}
|
||||
//const int8_t * values = iq4k_values + ((ls & 1) << 4);
|
||||
//float dl = d * ((ls & 254) - 127);
|
||||
//for (int j = 0; j < 16; ++j) {
|
||||
// y[j+ 0] = dl * values[aux8[j] & 0xf];
|
||||
// y[j+16] = dl * values[aux8[j] >> 4];
|
||||
//}
|
||||
y += 32;
|
||||
qs += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dot product of one iq4_kss row (vx) with one q8_K row (vy), result in *s.
// NOTE(review): only the iqk_mul_mat fast path is implemented. When
// GGML_USE_IQK_MULMAT is disabled, or iqk_mul_mat returns false, the function
// falls through without writing *s, leaving the result undefined — a scalar
// fallback still needs to be added (commit is marked WIP).
void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
#if GGML_USE_IQK_MULMAT
    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KSS, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
        return;
    }
#endif
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
}
|
||||
|
||||
|
||||
|
||||
@@ -61,6 +61,12 @@ size_t quantize_iq4_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst
|
||||
void dequantize_row_iq4_ks(const block_iq4_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq4_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void quantize_row_iq4_kss_ref(const float * GGML_RESTRICT x, block_iq4_kss * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq4_kss(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
size_t quantize_iq4_kss(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
void dequantize_row_iq4_kss(const block_iq4_kss * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq4_kss_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
|
||||
Reference in New Issue
Block a user