mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
Bitnet changes (#106)
* Adapting iq2_bn to work without separate scale tensors
Why? It is becoming burdensome to maintain the special Bitnet
conversion in convert_hf_to_gguf.py, so I thnk it is better
to make iq1_bn and iq2_bn just work with the mainline
conversion script (which does not generate scales).
* Adapting iq1_bn to work without separate scale tensors
* Adapting iq2_bn: CUDA dequantize
* Adapting iq2_bn: CUDA works
* Adapting iq1_bn: CUDA works
* Adapting iq1_bn, iq2_bn: NEON
* Adapting iq1_bn, iq2_bn: Metal
Dequantize works, but there is still something wrong
with the dot products.
* WIP
Absoolutely don't see what is wrong with the iq1_bn and iq2_bn
vector dot product kernels.
* Remove iq1_tn and iq2_tn - Part 1
Now that iq1_bn and iq2_bn have per row scales, there is no
reason to also have iq1_tn and iq2_tn.
* Remove iq1_tn and iq2_tn - Part 2
* Bitnet: use the standard llm_build_kv to build self attention
My main motivation was to enable FA. But FA does not work anyway
because head size is 100 for the Botnet ternary models
(and I had forgotten this little detail).
* Revert "Avoid rebuild of GGML graph for each token (#98)"
This reverts commit f2d315b46f.
As far as I can tell, the commit breaks Metal TG.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -119,6 +119,16 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
|
||||
|
||||
const int nblock = n_per_row/QK_IQ1BN;
|
||||
|
||||
ggml_half * dptr = (ggml_half *)y;
|
||||
y = (block_iq1_bn *)(dptr + 1);
|
||||
|
||||
float max = 0;
|
||||
for (int j = 0; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
|
||||
ggml_half d = GGML_FP32_TO_FP16(max);
|
||||
std::memcpy(dptr, &d, sizeof(d));
|
||||
|
||||
float thresh = 0.5f*max;
|
||||
|
||||
for (int ib = 0; ib < nblock; ++ib) {
|
||||
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
|
||||
auto xb = src + ib*QK_IQ1BN;
|
||||
@@ -128,14 +138,14 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
|
||||
int idx = 0;
|
||||
for (int j = 0; j < 5; ++j) {
|
||||
float v = xb[16*i16 + 5*k + j];
|
||||
int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
|
||||
int q = fabsf(v) < thresh ? 1 : v < 0 ? 0 : 2;
|
||||
idx += k_nb[j]*q;
|
||||
}
|
||||
idx = (256*idx + k_nb[5] - 1)/k_nb[5];
|
||||
y[ib].ql[3*i16 + k] = idx;
|
||||
}
|
||||
float v = xb[16*i16 + 15];
|
||||
int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
|
||||
int q = fabsf(v) < thresh ? 1 : v < 0 ? 0 : 2;
|
||||
v13 += k_nb[i16]*q;
|
||||
}
|
||||
y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5];
|
||||
@@ -150,10 +160,18 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
|
||||
|
||||
constexpr int Nj = QK_IQ1BN/4;
|
||||
|
||||
float max = 0;
|
||||
for (int j = 0; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
|
||||
|
||||
float * dptr = (float *)y;
|
||||
*dptr = max;
|
||||
y = (block_iq2_bn *)(dptr + 1);
|
||||
float thresh = 0.5f*max;
|
||||
|
||||
for (int ib = 0; ib < nblock; ++ib) {
|
||||
auto xb = src + QK_IQ1BN*ib;
|
||||
for (int j = 0; j < QK_IQ1BN; ++j) {
|
||||
L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
|
||||
L[j] = fabsf(xb[j]) < thresh ? 1 : xb[j] < 0 ? 0 : 2;
|
||||
}
|
||||
for (int j = 0; j < Nj; ++j) {
|
||||
y[ib].qs[j] = L[j] | (L[j + Nj] << 2) | (L[j + 2*Nj] << 4) | (L[j + 3*Nj] << 6);
|
||||
@@ -165,13 +183,13 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i
|
||||
|
||||
size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
|
||||
IQ1BNQuantizer iq1bn;
|
||||
int nblock = n_per_row/QK_IQ1BN;
|
||||
block_iq1_bn * y = (block_iq1_bn *)dst;
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);
|
||||
auto qrow = (char *)dst;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
iq1bn.quantize_one_row_1bn(src + row*n_per_row, y, n_per_row, imatrix);
|
||||
y += nblock;
|
||||
iq1bn.quantize_one_row_1bn(src + row*n_per_row, (block_iq1_bn *)qrow, n_per_row, imatrix);
|
||||
qrow += row_size;
|
||||
}
|
||||
return sizeof(block_iq1_bn)*nblock*nrows;
|
||||
return nrows*row_size;
|
||||
}
|
||||
|
||||
void quantize_row_iq1_bn_ref(const float * x, block_iq1_bn * y, int64_t k) {
|
||||
@@ -182,54 +200,6 @@ void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
|
||||
quantize_iq1_bn(x, y, 1, k, nullptr);
|
||||
}
|
||||
|
||||
void quantize_row_iq1_tn_ref(const float * x, block_iq1_tn * y, int64_t k) {
|
||||
quantize_iq1_tn(x, (void *)y, 1, k, nullptr);
|
||||
}
|
||||
|
||||
void quantize_row_iq1_tn(const float * x, void * y, int64_t k) {
|
||||
quantize_iq1_tn(x, y, 1, k, nullptr);
|
||||
}
|
||||
|
||||
size_t quantize_iq1_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
|
||||
GGML_ASSERT(n_per_row >= 2*QK_K); // so we have space for the scale
|
||||
int nblock = n_per_row/QK_IQ1BN;
|
||||
float tmp[QK_IQ1BN];
|
||||
char * qrow = (char *)dst;
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
|
||||
IQ1BNQuantizer iq1bn;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
float max = fabsf(src[0]);
|
||||
for (int j = 1; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
|
||||
if (!(max > 0)) printf("%s: found max = %g?\n", __func__, max);
|
||||
//GGML_ASSERT(max > 0);
|
||||
*(ggml_half *)qrow = GGML_FP32_TO_FP16(max);
|
||||
block_iq1_bn * y = (block_iq1_bn *)(qrow + sizeof(ggml_half));
|
||||
const float * xb = src;
|
||||
for (int ib = 0; ib < nblock; ++ib) {
|
||||
for (int j = 0; j < QK_IQ1BN; ++j) tmp[j] = xb[j] < -0.5f*max ? -1 : xb[j] <= 0.5f*max ? 0 : 1;
|
||||
iq1bn.quantize_one_row_1bn(tmp, y, QK_IQ1BN, imatrix);
|
||||
++y;
|
||||
xb += QK_IQ1BN;
|
||||
}
|
||||
src += n_per_row;
|
||||
qrow += row_size;
|
||||
}
|
||||
return nrows*row_size;
|
||||
}
|
||||
|
||||
void dequantize_row_iq1_tn(const block_iq1_tn * x, float * y, int64_t k) {
|
||||
float scale = GGML_FP16_TO_FP32(*(const ggml_half *)x);
|
||||
const block_iq1_bn * iq1bn = (const block_iq1_bn *)((const char *)x + sizeof(ggml_half));
|
||||
dequantize_row_iq1_bn(iq1bn, y, k);
|
||||
for (int j = 0; j < int(k); ++j) y[j] *= scale;
|
||||
}
|
||||
|
||||
void vec_dot_iq1_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
|
||||
float scale = GGML_FP16_TO_FP32(*(const ggml_half *)vx);
|
||||
ggml_vec_dot_iq1_bn_q8_K64(n, s, bs, (const void *)((const char *)vx + sizeof(ggml_half)), bx, vy, by, nrc);
|
||||
*s *= scale;
|
||||
}
|
||||
|
||||
void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
|
||||
assert(k%QK_IQ1BN == 0);
|
||||
int nblock = k / QK_IQ1BN;
|
||||
@@ -255,13 +225,13 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
|
||||
|
||||
size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
|
||||
IQ1BNQuantizer iq1bn;
|
||||
int nblock = n_per_row/QK_IQ1BN;
|
||||
block_iq2_bn * y = (block_iq2_bn *)dst;
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
|
||||
auto qrow = (char *)dst;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
iq1bn.quantize_one_row_2bn(src + row*n_per_row, y, n_per_row, imatrix);
|
||||
y += nblock;
|
||||
iq1bn.quantize_one_row_2bn(src + row*n_per_row, (block_iq2_bn *)qrow, n_per_row, imatrix);
|
||||
qrow += row_size;
|
||||
}
|
||||
return sizeof(block_iq2_bn)*nblock*nrows;
|
||||
return nrows*row_size;
|
||||
}
|
||||
|
||||
void quantize_row_iq2_bn_ref(const float * x, block_iq2_bn * y, int64_t k) {
|
||||
@@ -2369,114 +2339,6 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
|
||||
return nrows * nblock * sizeof(block_iq6_k);
|
||||
}
|
||||
|
||||
//
|
||||
// ========================== IQ2_TN
|
||||
//
|
||||
|
||||
void quantize_row_iq2_tn_ref(const float * x, block_iq2_tn * y, int64_t k) {
|
||||
GGML_ASSERT(k%QK_K == 0);
|
||||
|
||||
int nb = k/QK_K;
|
||||
|
||||
auto quantize = [] (float xmax, float x) {
|
||||
return x < -0.5f*xmax ? 0 : x < 0.5f*xmax ? 1 : 2;
|
||||
};
|
||||
int n = k;
|
||||
float max = x[0];
|
||||
for (int j = 1; j < n; ++j) max = std::max(max, fabsf(x[j]));
|
||||
|
||||
*(float *)y = max;
|
||||
y = (block_iq2_tn *)((float *)y + 1);
|
||||
|
||||
for (int ibl = 0; ibl < nb; ++ibl) {
|
||||
auto xb = x + QK_K*ibl;
|
||||
auto qs = y[ibl].qs;
|
||||
for (int l = 0; l < QK_K/128; ++l) {
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
qs[j] = quantize(max, xb[j]) | (quantize(max, xb[j+32]) << 2) | (quantize(max, xb[j+64]) << 4) | (quantize(max, xb[j+96]) << 6);
|
||||
}
|
||||
xb += 128;
|
||||
qs += 32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_iq2_tn(const float * x, void * y, int64_t k) {
|
||||
quantize_row_iq2_tn_ref(x, (block_iq2_tn *)y, k);
|
||||
}
|
||||
|
||||
size_t quantize_iq2_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * /*imatrix*/) {
|
||||
auto row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
|
||||
char * qrow = (char *)dst;
|
||||
for (int row = 0; row < nrows; ++row) {
|
||||
quantize_row_iq2_tn_ref(src, (block_iq2_tn *)qrow, n_per_row);
|
||||
qrow += row_size;
|
||||
src += n_per_row;
|
||||
}
|
||||
return row_size*nrows;
|
||||
}
|
||||
|
||||
void dequantize_row_iq2_tn(const block_iq2_tn * x, float * y, int64_t k) {
|
||||
GGML_ASSERT(k%QK_K == 0);
|
||||
const float * dptr = (const float *)x;
|
||||
float d = *dptr;
|
||||
x = (const block_iq2_tn *)(dptr + 1);
|
||||
int nb = k/QK_K;
|
||||
for (int ibl = 0; ibl < nb; ++ibl) {
|
||||
auto qs = x[ibl].qs;
|
||||
for (int l = 0; l < QK_K/128; ++l) {
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
y[j+ 0] = d*((qs[j] >> 0) & 3) - d;
|
||||
y[j+32] = d*((qs[j] >> 2) & 3) - d;
|
||||
y[j+64] = d*((qs[j] >> 4) & 3) - d;
|
||||
y[j+96] = d*((qs[j] >> 6) & 3) - d;
|
||||
}
|
||||
y += 128;
|
||||
qs += 32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void vec_dot_iq2_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
|
||||
GGML_UNUSED(bs);
|
||||
GGML_UNUSED(bx);
|
||||
GGML_UNUSED(by);
|
||||
GGML_UNUSED(nrc);
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_TN, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
const float * dptr = (const float *)vx;
|
||||
const float d = *dptr;
|
||||
const block_iq2_tn * x = (const block_iq2_tn *)(dptr + 1);
|
||||
const block_q8_K * y = (const block_q8_K *)vy;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
auto qs = x[i].qs;
|
||||
auto q8 = y[i].qs;
|
||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0,sumi4 = 0;
|
||||
for (int j = 0; j < QK_K/16; ++j) sumi1 -= y[i].bsums[j];
|
||||
for (int l = 0; l < QK_K/128; ++l) {
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
sumi1 += q8[j+ 0] * (qs[j] & 0x03);
|
||||
sumi2 += q8[j+32] * (qs[j] & 0x0c);
|
||||
sumi3 += q8[j+64] * (qs[j] & 0x30);
|
||||
sumi4 += q8[j+96] * (qs[j] & 0xc0);
|
||||
}
|
||||
q8 += 128;
|
||||
qs += 32;
|
||||
}
|
||||
sumf += d * (sumi1 + 0.25f*sumi2 + 0.0625f*sumi3 + 0.015625f*sumi4);
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
namespace {
|
||||
inline int hsum_i32_8(const __m256i a) {
|
||||
@@ -2941,7 +2803,6 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
continue;
|
||||
}
|
||||
float best = 0;
|
||||
bool is_shifted = false;
|
||||
float d = -max/iq4k_values[0];
|
||||
std::memset(vs, 0, block_size);
|
||||
for (int itry = -ntry; itry <= ntry; ++itry) {
|
||||
@@ -2974,10 +2835,10 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
}
|
||||
bool copy_p = false, copy_m = false;
|
||||
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
|
||||
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = false; copy_p = true;
|
||||
d = sumqx_p/sumq2_p; best = d * sumqx_p; copy_p = true;
|
||||
}
|
||||
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
|
||||
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; copy_m = true;
|
||||
d = sumqx_m/sumq2_m; best = d * sumqx_m; copy_m = true;
|
||||
}
|
||||
if (copy_m) {
|
||||
std::memcpy(vs, vms, block_size);
|
||||
@@ -3014,10 +2875,10 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
|
||||
}
|
||||
copy_p = copy_m = false;
|
||||
if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) {
|
||||
d = sumqx_p/sumq2_p; best = d * sumqx_p; is_shifted = true; copy_p = true;
|
||||
d = sumqx_p/sumq2_p; best = d * sumqx_p; copy_p = true;
|
||||
}
|
||||
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
|
||||
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; copy_m = true;
|
||||
d = sumqx_m/sumq2_m; best = d * sumqx_m; copy_m = true;
|
||||
}
|
||||
if (copy_m) {
|
||||
std::memcpy(vs, vms, block_size);
|
||||
|
||||
Reference in New Issue
Block a user