Adding IQ1_TN - 1.6875 bpw for TriLM ternary models (#44)

* Adding iq1_tn - 1.6875 bpw for TriLM ternary models

* iq1_tn: NEON

* iq1_tn: faster NEON

* iq2_bn: improve performance on NEON

We now get TG-128 = 100 t/s for Bitnet-3B-1.58b!

* iq1_tn: improve AVX2

PP-512 goes to 533 t/s up from 455.
TG-128 @ 2 threads goes to 16.6 t/s up from 14.2.
However, we seem to have a bottleneck somewhere as
TG saturates at 8 threads.

* iq1_tn: improve Zen4

PP-512 goes to 485 t/s up from 352. With FA we get 545 t/s up from 380.
TG-128 @ 1 thread goes to 12.4 t/s up from 10.4.
However, we seem to have a bottleneck somewhere as
TG saturates at 8 threads.

* iq2_bn: improve on Zen4

We now get PP-512 = 614 t/s up from 542 t/s

* iq2_bn: improve AVX2 implementation

We now get PP-512 = 753 t/s up from 680 t/s.

* Remove unnecessary barrier in ggml_compute_forward_mul_mat

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2024-09-09 14:56:34 +03:00
committed by GitHub
parent f2ef628e4e
commit 4a5d5e207d
10 changed files with 304 additions and 148 deletions

View File

@@ -119,6 +119,54 @@ void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
quantize_iq1_bn(x, y, 1, k, nullptr);
}
void quantize_row_iq1_tn_ref(const float * x, block_iq1_tn * y, int64_t k) {
quantize_iq1_tn(x, (void *)y, 1, k, nullptr);
}
void quantize_row_iq1_tn(const float * x, void * y, int64_t k) {
quantize_iq1_tn(x, y, 1, k, nullptr);
}
size_t quantize_iq1_tn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
GGML_ASSERT(n_per_row >= 2*QK_K); // so we have space for the scale
int nblock = n_per_row/QK_IQ1BN;
float tmp[QK_IQ1BN];
char * qrow = (char *)dst;
auto row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
IQ1BNQuantizer iq1bn;
for (int row = 0; row < nrows; ++row) {
float max = fabsf(src[0]);
for (int j = 1; j < n_per_row; ++j) max = std::max(max, fabsf(src[j]));
if (!(max > 0)) printf("%s: found max = %g?\n", __func__, max);
//GGML_ASSERT(max > 0);
*(ggml_half *)qrow = GGML_FP32_TO_FP16(max);
block_iq1_bn * y = (block_iq1_bn *)(qrow + sizeof(ggml_half));
const float * xb = src;
for (int ib = 0; ib < nblock; ++ib) {
for (int j = 0; j < QK_IQ1BN; ++j) tmp[j] = xb[j] < -0.5f*max ? -1 : xb[j] <= 0.5f*max ? 0 : 1;
iq1bn.quantize_one_row_1bn(tmp, y, QK_IQ1BN, imatrix);
++y;
xb += QK_IQ1BN;
}
src += n_per_row;
qrow += row_size;
}
return nrows*row_size;
}
void dequantize_row_iq1_tn(const block_iq1_tn * x, float * y, int64_t k) {
float scale = GGML_FP16_TO_FP32(*(const ggml_half *)x);
const block_iq1_bn * iq1bn = (const block_iq1_bn *)((const char *)x + sizeof(ggml_half));
dequantize_row_iq1_bn(iq1bn, y, k);
for (int j = 0; j < int(k); ++j) y[j] *= scale;
}
void vec_dot_iq1_tn_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
float scale = GGML_FP16_TO_FP32(*(const ggml_half *)vx);
ggml_vec_dot_iq1_bn_q8_K64(n, s, bs, (const void *)((const char *)vx + sizeof(ggml_half)), bx, vy, by, nrc);
*s *= scale;
}
void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
@@ -331,8 +379,10 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
GGML_ASSERT(k >= 8*QK_IQ1BN);
float * dptr = (float *)y;
auto qs = (int8_t *)(dptr + 4);
auto qs = (int8_t *)(dptr + 8);
#ifdef __ARM_NEON
static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
auto shuffle = vld1q_u8(k_shuffle);
@@ -351,16 +401,22 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
vid[i] = vdupq_n_f32(id);
}
int8x16x4_t q;
int32x4_t qsum = {};
const int8x16_t m1 = vdupq_n_s8(1);
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
auto val = vld1q_f32(x + j + 4*i);
val = vmulq_f32(vid[i], val);
q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
auto ival = vcvtnq_s32_f32(val);
q.val[i] = vreinterpretq_s8_s32(ival);
}
auto qi = vqtbl4q_s8(q, shuffle);
qsum = ggml_vdotq_s32(qsum, qi, m1);
vst1q_s8(qs, qi);
qs += 16;
}
auto sumf = vmulq_f32(vld1q_f32(dptr), vcvtq_f32_s32(qsum));
vst1q_f32(dptr + 4, sumf);
#elif defined __AVX__
__m128 max[4] = {};
__m128 sign_bit = _mm_set1_ps(-0.f);
@@ -381,6 +437,9 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
vid[i] = _mm_set1_ps(id);
}
__m128i q[4];
__m128i sums = _mm_setzero_si128();
__m128i m1_8 = _mm_set1_epi8(1);
__m128i m1_16 = _mm_set1_epi16(1);
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
auto val = _mm_loadu_ps(x + j + 4*i);
@@ -390,9 +449,13 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
auto q1 = _mm_packs_epi32(q[0], q[1]);
auto q2 = _mm_packs_epi32(q[2], q[3]);
auto qi = _mm_packs_epi16(q1, q2);
auto aux = _mm_maddubs_epi16(m1_8, qi);
sums = _mm_add_epi32(sums, _mm_madd_epi16(m1_16, aux));
_mm_storeu_si128((__m128i *)qs, qi);
qs += 16;
}
auto minus = _mm_mul_ps(_mm_loadu_ps(dptr), _mm_cvtepi32_ps(sums));
_mm_storeu_ps(dptr + 4, minus);
#else
float aux[4] = {0.f, 0.f, 0.f, 0.f};
for (int j = 0; j < k; j += 16) {
@@ -407,11 +470,16 @@ void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
dptr[i] = aux[i]/127;
aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
}
int32_t sum[4] = {};
for (int j = 0; j < k; j += 16) {
for (int i = 0; i < 4; ++i) {
for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
for (int l = 0; l < 4; ++l) {
qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
sum[i] += qs[j+4*i+l];
}
}
}
for (int i = 0; i < 4; ++i) dptr[4+i] = dptr[i]*sum[i];
#endif
}