mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 09:29:27 +00:00
iq2_tn: TriLM specific 2.0625 bpw quantization
Quantize/dequantize/scalar dot product. I get 46 t/s for the TriLM-3.9B without any SIMD! Finally a compiler doing a decent job auto-vectorizing the scalar implementation.
This commit is contained in:
@@ -393,6 +393,7 @@ extern "C" {
|
||||
GGML_TYPE_IQ3_K = 38,
|
||||
GGML_TYPE_IQ4_K = 39,
|
||||
GGML_TYPE_IQ5_K = 40,
|
||||
GGML_TYPE_IQ2_TN = 41,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
@@ -443,6 +444,7 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
|
||||
@@ -407,7 +407,7 @@ typedef struct {
|
||||
static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
|
||||
|
||||
//
|
||||
// Bitnet - implemented as 1.75 bpw
|
||||
// Bitnet - implemented as 1.625 bpw
|
||||
// The block scale is a waste, but it allows us to plug it in without any additional
|
||||
// changes to ggml.
|
||||
//
|
||||
@@ -418,13 +418,21 @@ typedef struct {
|
||||
} block_iq1_bn;
|
||||
static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
|
||||
//
|
||||
// Bitnet - implemented as 2.25 bpw
|
||||
// Bitnet - implemented as 2.0 bpw
|
||||
//
|
||||
// Bitnet 2.0 bpw block: 64 ternary weights, 2 bits each, 4 weights per byte.
// No per-block scale is stored for this type.
#define QK_IQ2BN 64
typedef struct {
    uint8_t qs[QK_IQ2BN/4];  // packed 2-bit quants, 4 values per byte
} block_iq2_bn;
static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
|
||||
//
|
||||
// TriLM - implemented as 2.0625 bpw
|
||||
//
|
||||
typedef struct {
|
||||
ggml_half d;
|
||||
uint8_t qs[QK_K/4];
|
||||
} block_iq2_tn;
|
||||
static_assert(sizeof(block_iq2_tn) == sizeof(ggml_half) + QK_K/4, "wrong iqt_bn block size/padding");
|
||||
|
||||
// Used by IQ1_M quants
|
||||
typedef union {
|
||||
|
||||
@@ -14996,6 +14996,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
case GGML_TYPE_IQ3_K: break;
|
||||
case GGML_TYPE_IQ4_K: break;
|
||||
case GGML_TYPE_IQ5_K: break;
|
||||
case GGML_TYPE_IQ2_TN: break;
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
{
|
||||
|
||||
@@ -882,6 +882,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K64,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ2_TN] = {
|
||||
.type_name = "iq2_tn",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_iq2_tn),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_tn,
|
||||
.from_float = quantize_row_iq2_tn,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_tn_ref,
|
||||
.vec_dot = vec_dot_iq2_tn_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ4_NL] = {
|
||||
.type_name = "iq4_nl",
|
||||
.blck_size = QK4_NL,
|
||||
@@ -3375,6 +3387,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||
case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_TN: wtype = GGML_TYPE_IQ2_TN; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
|
||||
@@ -9628,6 +9641,7 @@ static void ggml_compute_forward_add(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -10012,6 +10026,7 @@ static void ggml_compute_forward_add1(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -10146,6 +10161,7 @@ static void ggml_compute_forward_acc(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -13069,6 +13085,7 @@ static void ggml_compute_forward_out_prod(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -13263,6 +13280,7 @@ static void ggml_compute_forward_set(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -13531,6 +13549,7 @@ static void ggml_compute_forward_get_rows(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -14126,6 +14145,7 @@ static void ggml_compute_forward_clamp(
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_TN:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
@@ -20865,6 +20885,7 @@ size_t ggml_quantize_chunk(
|
||||
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_TN: result = quantize_iq2_tn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
|
||||
@@ -1514,3 +1514,110 @@ size_t quantize_iq5_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
|
||||
}
|
||||
return nrows * nblock * sizeof(block_iq5_k);
|
||||
}
|
||||
|
||||
//
|
||||
// ========================== IQ2_TN
|
||||
//
|
||||
|
||||
void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
|
||||
GGML_ASSERT(k%QK_K == 0);
|
||||
|
||||
int nb = k/QK_K;
|
||||
|
||||
auto quantize = [] (float xmax, float x) {
|
||||
return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
|
||||
};
|
||||
|
||||
for (int ibl = 0; ibl < nb; ++ibl) {
|
||||
auto xb = x + QK_K*ibl;
|
||||
float max = xb[0];
|
||||
for (int j = 0; j < QK_K; ++j) {
|
||||
float ax = fabsf(xb[j]);
|
||||
max = std::max(ax, max);
|
||||
}
|
||||
y[ibl].d = GGML_FP32_TO_FP16(max);
|
||||
auto qs = y[ibl].qs;
|
||||
for (int l = 0; l < QK_K/128; ++l) {
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6);
|
||||
}
|
||||
xb += 128;
|
||||
qs += 32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
|
||||
quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k);
|
||||
}
|
||||
|
||||
size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) {
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
|
||||
char * qrow = (char *)dst;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row);
|
||||
qrow += row_size;
|
||||
src += n_per_row;
|
||||
}
|
||||
return row_size*nrows;
|
||||
}
|
||||
|
||||
void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
|
||||
GGML_ASSERT(k%QK_K == 0);
|
||||
int nb = k/QK_K;
|
||||
for (int ibl = 0; ibl < nb; ++ibl) {
|
||||
float d = GGML_FP16_TO_FP32(x[ibl].d);
|
||||
auto qs = x[ibl].qs;
|
||||
for (int l = 0; l < QK_K/128; ++l) {
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
y[j+ 0] = d*((qs[j] >> 0) & 3) - d;
|
||||
y[j+32] = d*((qs[j] >> 2) & 3) - d;
|
||||
y[j+64] = d*((qs[j] >> 4) & 3) - d;
|
||||
y[j+96] = d*((qs[j] >> 6) & 3) - d;
|
||||
}
|
||||
y += 128;
|
||||
qs += 32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dot product of one IQ2_TN row (vx) with one Q8_K row (vy) over n values.
// Tries the optimized iqk_mul_mat path first; the scalar fallback below is
// written so the compiler can auto-vectorize it.
void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    // Fast path: SIMD kernel handles the whole row if available.
    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
        return;
    }

    const int nb = n / QK_K;

    const block_iq2_tn * x = (const block_iq2_tn *)vx;
    const block_q8_K * y = (const block_q8_K *)vy;

    float sumf = 0;

    for (int i = 0; i < nb; i++) {
        // Combined scale for this super-block.
        float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        auto qs = x[i].qs;
        auto q8 = y[i].qs;
        int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0;
        // The stored codes q are in {0,1,2} but represent q-1 in {-1,0,+1}.
        // Subtracting the Q8_K partial sums (bsums) up front applies the -1
        // offset for all QK_K positions at once: result = sum q8*(q-1).
        for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j];
        for (int l = 0; l < QK_K/128; ++l) {
            for (int j = 0; j < 32; ++j) {
                // Codes are masked in place (not shifted down), so sumi2..4
                // accumulate values scaled by 4, 16, 64; the fractional
                // factors below undo that. Keeping a plain mask per lane
                // helps the compiler auto-vectorize this loop.
                sumi1 += q8[j+ 0] * (qs[j] & 0x03);
                sumi2 += q8[j+32] * (qs[j] & 0x0c);
                sumi3 += q8[j+64] * (qs[j] & 0x30);
                sumi4 += q8[j+96] * (qs[j] & 0xc0);
            }
            q8 += 128;
            qs += 32;
        }
        // 0.25 = 1/4, 0.0625 = 1/16, 0.015625 = 1/64 compensate the mask positions.
        sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4);
    }
    *s = sumf;
}
|
||||
|
||||
|
||||
@@ -37,6 +37,12 @@ size_t quantize_iq5_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
|
||||
void dequantize_row_iq5_k(const block_iq5_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq5_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
// IQ2_TN (TriLM, 2.0625 bpw ternary) quantization interface.
// Reference row quantization: x has k floats, y receives k/QK_K blocks.
void quantize_row_iq2_tn_ref(const float * GGML_RESTRICT x, block_iq2_tn * GGML_RESTRICT y, int64_t k);
// Type-erased wrapper around the reference implementation (ggml_from_float_t).
void quantize_row_iq2_tn(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Quantize nrows rows of n_per_row values; imatrix is accepted but unused. Returns bytes written.
size_t quantize_iq2_tn(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
// Dequantize k values from x into y.
void dequantize_row_iq2_tn(const block_iq2_tn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product of an IQ2_TN row (vx) with a Q8_K row (vy); result stored in *s.
void vec_dot_iq2_tn_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user