mirror of https://github.com/ikawrakow/ik_llama.cpp.git
iq1bn(no lookup): better version
We have 4 groups of 16 in a block of 64 quants. Each group of 16 is stored as 3 groups of 5 quants, with each group of 5 packed into 8 bits. The remaining 16th quant of each of the 4 groups is encoded in the single `extra` byte, using the same encoding as the groups of 5. The only kernel where we have complications is the CUDA dequantize kernel (because we are dequantizing 8 quants there, and the 1st and 2nd group of 8 within a group of 16 have different encodings). This achieves better performance on all tested platforms than any previous 1.625 bpw attempt. We have:

| model            |       size |  params | backend | threads |  test |             t/s |
| ---------------- | ---------: | ------: | ------- | ------: | ----: | --------------: |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | CUDA    |       8 | pp512 | 9613.02 ± 24.54 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | CUDA    |       8 | tg128 |   229.85 ± 0.33 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |      16 | pp512 |   322.59 ± 1.00 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |      16 | tg128 |    59.79 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       8 | tg128 |    57.62 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       4 | tg128 |    33.66 ± 0.29 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       2 | tg128 |    18.30 ± 0.01 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | Metal   |       8 | pp512 |   698.13 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | Metal   |       8 | tg128 |    68.88 ± 0.24 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       8 | pp512 |   196.80 ± 0.50 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       8 | tg128 |    51.58 ± 0.41 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       4 | tg128 |    30.80 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       2 | tg128 |    16.89 ± 0.01 |

It is still slower than 2 bpw Bitnet, but the difference now is not as dramatic.
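To make the packing concrete, here is a minimal standalone round-trip sketch (illustrative only, not part of the commit; `pack5`/`unpack5` are made-up names, but the arithmetic mirrors `quantize_one_row_1bn` and `dequantize_row_iq1_bn` in the diff below). Five ternary quants q_j in {0,1,2} form idx = q_0 + 3*q_1 + 9*q_2 + 27*q_3 + 81*q_4 in 0...242; the byte stored in `ql` is ceil(256*idx/243); digit j is recovered with a wrapping 8-bit multiply by k_mult[j] in {81, 27, 9, 3, 1} followed by taking the high byte of 3*v, where `(v + (v >> 1)) >> 7` is an equivalent way to compute `3*v >> 8`:

```cpp
// Round-trip check of the 5-quants-per-byte packing (illustrative sketch;
// pack5/unpack5 are hypothetical helpers, not identifiers from the commit).
#include <cstdint>
#include <cstdio>

static uint8_t pack5(const int q[5]) {
    int idx = 0, pow3 = 1;
    for (int j = 0; j < 5; ++j) { idx += q[j]*pow3; pow3 *= 3; }  // idx in 0...242
    return (256*idx + 242)/243;  // ceil(256*idx/243), fits in one byte
}

static void unpack5(uint8_t b, int q[5]) {
    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
    for (int j = 0; j < 5; ++j) {
        uint8_t v = k_mult[j]*b;      // wraps mod 256 on purpose; isolates digit j
        q[j] = (v + (v >> 1)) >> 7;   // == 3*v >> 8, i.e. the ternary digit 0, 1 or 2
    }
}

int main() {
    int bad = 0;
    for (int idx = 0; idx < 243; ++idx) {
        int q[5], r[5], t = idx;
        for (int j = 0; j < 5; ++j) { q[j] = t%3; t /= 3; }
        unpack5(pack5(q), r);
        for (int j = 0; j < 5; ++j) if (q[j] != r[j]) ++bad;
    }
    printf("mismatches: %d\n", bad);  // prints 0: every combination survives
    return 0;
}
```

The `extra` byte packs the four 16th quants the same way (four base-3 digits decoded with multipliers 81, 27, 9, 3), so a block of 64 ternary quants takes exactly 12 + 1 = 13 bytes, i.e. 1.625 bpw.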
@@ -380,11 +380,10 @@ static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m bl
 //
 #define QK_IQ1BN 64
 typedef struct {
+    uint8_t ql[12];
     uint8_t extra;
-    uint8_t ql[QK_IQ1BN/8];
-    uint8_t qh[QK_IQ1BN/16];
 } block_iq1_bn;
-static_assert(sizeof(block_iq1_bn) == sizeof(uint8_t) + QK_IQ1BN/8 + QK_IQ1BN/16, "wrong iq1_bn block size/padding");
+static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
 //
 // Bitnet - implemented as 2.25 bpw
 //
@@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
     const int64_t ii = blockIdx.x;
     const block_iq1_bn * x = (const block_iq1_bn *) vx;

-    static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};

+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
+
     const int tid = threadIdx.x;
     const int il = tid/4; // 0...7
@@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
     dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
     int64_t i = QK_K/QK_IQ1BN * ii + ib;
     if (i >= nb64) return;
-    uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
-    for (int j = 0; j < 8; ++j) {
-        uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
-        y[j] = v - 1;
+    const int i16 = il/2;
+    uint8_t q = x[i].ql[3*i16+2*(il%2)];
+    for (int j = 0; j < 5; ++j) {
+        uint8_t v = k_mult[j]*q;
+        int8_t vs = COMPUTE_VS(v);
+        y[2*(il%2)+j] = vs - 1;
     }
+    q = x[i].ql[3*i16+1];
+    for (int j = 0; j < 2; ++j) {
+        uint8_t v = k_mult[3*(il%2)+j]*q;
+        int8_t vs = COMPUTE_VS(v);
+        y[5*(1-(il%2))+j] = vs-1;
+    }
+    uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+    int8_t vs = COMPUTE_VS(v);
+    y[7] = vs - 1;
+
+#undef COMPUTE_VS
 }

 template<typename dst_t>
@@ -1078,67 +1078,47 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
     const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;

-    static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};

     // iqs is 0 or 1

-    uint8_t extra = bq1->extra >> 4*iqs;
     int sumi = 0;
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const int * q8 = (const int *)bq8_1[iqs].qs;
-    //int v[2];
-    //int8_t * a = (int8_t *)v;
-    //for (int l = 0; l < 2; ++l) {
-    //    int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-    //    int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-    //    for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi));
-    //    for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi));
-    //    extra >>= 2;
-    //}
-
-    int v[4];
-    int8_t * a = (int8_t *)v;
+    int val[4];
     for (int l = 0; l < 2; ++l) {
-        int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-        int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-        for (int k = 0; k < 8; ++k) {
-            a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-            a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
+        int8_t * a = (int8_t *)val;
+        const int i16 = 2*iqs + l;
+        for (int k = 0; k < 3; ++k) {
+            uint8_t q = bq1->ql[3*i16+k];
+            for (int j = 0; j < 5; ++j) {
+                uint8_t v = k_mult[j]*q;
+                int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+                *a++ = vs-1;
+            }
         }
-        sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi))));
-        extra >>= 2;
+        uint8_t v = k_mult[i16]*bq1->extra;
+        int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+        *a++ = vs-1;
+        sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
     }
-
-    //int v[8];
-    //int8_t * a = (int8_t *)v;
-    //int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-    //int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-    //int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096);
-    //int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096);
-    //for (int k = 0; k < 8; ++k) {
-    //    a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //    a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1;
-    //}
-    //sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi))));
-    //sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi))));
 #else
     const int8_t * q8 = bq8_1[iqs].qs;
     for (int l = 0; l < 2; ++l) {
-        int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
-        int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
-        for (int k = 0; k < 4; ++k) {
-            int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13;
-            int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13;
-            int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13;
-            int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13;
-            sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12];
+        const int i16 = 2*iqs + l;
+        for (int k = 0; k < 3; ++k) {
+            uint8_t q = bq1->ql[3*i16+k];
+            for (int j = 0; j < 5; ++j) {
+                uint8_t v = k_mult[j]*q;
+                int8_t vs = (v + (v >> 1)) >> 7;
+                sumi += q8[j]*(vs - 1);
+            }
+            q8 += 5;
         }
-        q8 += 16;
-        extra >>= 2;
+        uint8_t v = k_mult[i16]*bq1->extra;
+        int8_t vs = (v + (v >> 1)) >> 7;
+        sumi += q8[0]*(vs - 1);
+        q8++;
     }
 #endif
     return __low2float(bq8_1[iqs].ds) * sumi;
ggml-metal.metal
@@ -5054,6 +5054,49 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) {
     return s.f;
 }

+//static constant int8_t iq1bn_values[256*5] = {
+//    -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+//    -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+//    -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+//    -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+//    -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+//    -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+//    0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+//    -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+//    0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+//    1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+//    -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+//    0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+//    -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+//    1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+//    -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+//    0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+//    0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+//    0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+//    -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+//    0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+//    0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+//    1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+//    0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+//    0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+//    1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+//    0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+//    1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+//    1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+//    1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+//    -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+//    1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+//    1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+//    1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+//    -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+//    0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+//    1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+//    -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+//    1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+//    0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+//    1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+//};
+
 void kernel_mul_mv_iq1_bn_f32_impl(
         device const void * src0,
         device const float * src1,
@@ -5087,53 +5130,62 @@ void kernel_mul_mv_iq1_bn_f32_impl(
     device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

-    float4 yl[2];
-    float sumf[N_DST]={0.f}, all_sum;
+    float yl[16];
+    float sumf[N_DST]={0.f};

     const int nb32 = nb * (QK_IQ1BN / 32);

-    const int ix = tiisg/4;
-    const int ir = tiisg%4;
+    const int ix = tiisg/2;
+    const int ir = tiisg%2;

-    device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir;
+    device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;

-    uint32_t aux32[2];
-    thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
-
     const float values[3] = {-1.f, 0.f, 1.f};

-    constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};

-    for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
+    for (int ib32 = ix; ib32 < nb32; ib32 += 16) {

-        yl[0] = y4[0]; yl[1] = y4[1];
+        for (int j = 0; j < 16; ++j) yl[j] = y4[j];

         const int ibl = ib32 / (QK_IQ1BN / 32);
         const int ib  = ib32 % (QK_IQ1BN / 32);
-        const int il  = 4*ib + ir;
+        const int i16 = 2*ib + ir;

         device const block_iq1_bn * xr = x + ibl;
-        device const uint8_t * ql = xr->ql + il;
-        device const uint8_t * qh = xr->qh + il%4;
+        device const uint8_t * ql = xr->ql + 3*i16;
         device const uint8_t * extra = (device const uint8_t *)&xr->extra;

         for (int row = 0; row < N_DST; row++) {

-            uint8_t h = extra[0] >> il;
-            int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096);
-            float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13],
-                                         values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]}
-                        + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13],
-                                         values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]};
-            sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3];
+            float acc = 0;
+            int i = 0;
+            for (int k = 0; k < 3; ++k) {
+                //constant int8_t * vs = iq1bn_values + 5*ql[k];
+                //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
+                uint8_t q = ql[k];
+                for (int j = 0; j < 5; ++j) {
+                    uint8_t v = k_mult[j]*q;
+                    v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+                    acc += yl[i++] * values[v];
+                }
+            }
+            //constant int8_t * vs = iq1bn_values + 5*extra[0];
+            //acc += yl[15] * vs[i16];
+            uint8_t v = k_mult[i16]*extra[0];
+            v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+            acc += yl[15] * values[v];
+            sumf[row] += acc;

             extra += nb*sizeof(block_iq1_bn);
             ql    += nb*sizeof(block_iq1_bn);
-            qh    += nb*sizeof(block_iq1_bn);
         }

-        y4 += 32 * 2;
+        y4 += 32 * 16;
     }

     for (int row = 0; row < N_DST; row += 2) {
@@ -5990,18 +6042,23 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
 template <typename type4x4>
 void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
     // il is in 0...3
-    uint8_t gs = xb->extra >> 2*il;

-    constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+    constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};

-    short il1 = 2*il+0, il2 = 2*il+1;
-    int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096);
-    int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096);
-
-    for (int i = 0; i < 8; ++i) {
-        reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1;
-        reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1;
+    int i = 0;
+    for (int k = 0; k < 3; ++k) {
+        uint8_t q = xb->ql[3*il + k];
+        for (int j = 0; j < 5; ++j) {
+            uint8_t v = k_mult[j]*q;
+            int8_t vs = 3*v >> 8;
+            //int8_t vs = (v + (v >> 1)) >> 7;
+            reg[i/4][i%4] = vs - 1;
+            ++i;
+        }
     }
+    uint8_t v = k_mult[il]*xb->extra;
+    int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+    reg[3][3] = vs - 1;
 }

 template <typename type4x4>
iqk-quantize.cpp
@@ -118,7 +118,7 @@ uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const fl
 void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {

-    static const int k_nb[8] = {1, 3, 9, 27, 81, 243, 729, 2187};
+    static const int k_nb[6] = {1, 3, 9, 27, 81, 243};
     (void)imatrix;

     const int nblock = n_per_row/QK_IQ1BN;
@@ -126,21 +126,24 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
     for (int ib = 0; ib < nblock; ++ib) {
         std::memset(&y[ib], 0, sizeof(block_iq1_bn));
         auto xb = src + ib*QK_IQ1BN;
-        for (int i = 0; i < QK_IQ1BN/8; ++i) {
-            int idx = 0;
-            for (int j = 0; j < 8; ++j) {
-                float v = xb[8*i + j];
-                int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
-                idx += k_nb[j]*q;
+        int v13 = 0;
+        for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+            for (int k = 0; k < 3; ++k) {
+                int idx = 0;
+                for (int j = 0; j < 5; ++j) {
+                    float v = xb[16*i16 + 5*k + j];
+                    int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+                    idx += k_nb[j]*q;
+                }
+                idx = (256*idx + k_nb[5] - 1)/k_nb[5];
+                y[ib].ql[3*i16 + k] = idx;
             }
-            idx = (8192*idx + 6560)/6561;
-            y[ib].ql[i] = idx & 255;
-            y[ib].qh[i%4] |= ((idx >> 8) & 0xf) << 4*(i/4);
-            y[ib].extra |= (idx >> 12) << i;
+
+            float v = xb[16*i16 + 15];
+            int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+            v13 += k_nb[i16]*q;
         }
+        y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5];
     }

 }

 void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
@@ -194,18 +197,23 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
     assert(k%QK_IQ1BN == 0);
     int nblock = k / QK_IQ1BN;

-    static const int k_mult[8] = {17496, 5832, 1944, 648, 216, 72, 24, 8};
+    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};

     for (int i = 0; i < nblock; ++i) {
         uint8_t extra = x[i].extra;
-        auto qh = x[i].qh;
         auto ql = x[i].ql;
-        for (int k = 0; k < QK_IQ1BN/8; ++k) {
-            uint16_t idx = ql[k] | ((qh[k%4] << (8 - 4*(k/4))) & 0x0f00) | ((extra << (12 - k)) & 4096);
-            for (int j = 0; j < 8; ++j) {
-                int v = (idx*k_mult[j] & 0xffff)*3 >> 16;
-                *y++ = v - 1;
+        for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+            for (int k = 0; k < 3; ++k) {
+                for (int j = 0; j < 5; ++j) {
+                    uint8_t v = ql[k]*k_mult[j];
+                    int8_t vs = ((v + (v >> 1)) >> 7);
+                    *y++ = vs - 1;
+                }
             }
+            ql += 3;
+            uint8_t v = extra*k_mult[i16];
+            int8_t vs = ((v + (v >> 1)) >> 7);
+            *y++ = vs - 1;
         }
     }
 }
@@ -260,42 +268,44 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
         return;
     }

-    constexpr uint16_t k_magic = 0xaaaa;
-
-    const block_iq1_bn * x = (const block_iq1_bn *)vx;
-
-    const float * d8 = (const float *)vy;
-    const int8_t * q8 = (const int8_t *)(d8 + 4);
-    int nblock = n / QK_IQ1BN;
-
-    int sumi[8] = {};
-    uint32_t aux32[2];
-    const int8_t * aux8 = (const int8_t *)aux32;
-
-    for (int i = 0; i < nblock; ++i) {
-        auto qh = x[i].qh;
-        auto ql = x[i].ql;
-        auto extra = x[i].extra;
-        for (int j = 0; j < QK_IQ1BN/16; ++j) {
-            uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
-            uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
-            uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
-            uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
-            extra >>= 2;
-            aux32[0] = val1 | (val1 << 14);
-            aux32[1] = (aux32[0] >> 4) & 0x03030303;
-            aux32[0] &= 0x03030303;
-            for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
-            q8 += 8;
-            aux32[0] = val2 | (val2 << 14);
-            aux32[1] = (aux32[0] >> 4) & 0x03030303;
-            aux32[0] &= 0x03030303;
-            for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
-            q8 += 8;
-        }
-    }
-
-    *s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
+    // TODO
+
+    //constexpr uint16_t k_magic = 0xaaaa;
+
+    //const block_iq1_bn * x = (const block_iq1_bn *)vx;
+
+    //const float * d8 = (const float *)vy;
+    //const int8_t * q8 = (const int8_t *)(d8 + 4);
+    //int nblock = n / QK_IQ1BN;
+
+    //int sumi[8] = {};
+    //uint32_t aux32[2];
+    //const int8_t * aux8 = (const int8_t *)aux32;
+
+    //for (int i = 0; i < nblock; ++i) {
+    //    auto qh = x[i].qh;
+    //    auto ql = x[i].ql;
+    //    auto extra = x[i].extra;
+    //    for (int j = 0; j < QK_IQ1BN/16; ++j) {
+    //        uint16_t idx1 = ql[2*j+0] | ((qh[j] << 8) & 0x0f00);
+    //        uint16_t idx2 = ql[2*j+1] | ((qh[j] << 4) & 0x0f00);
+    //        uint16_t val1 = extra & 1 ? k_magic - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
+    //        uint16_t val2 = extra & 2 ? k_magic - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
+    //        extra >>= 2;
+    //        aux32[0] = val1 | (val1 << 14);
+    //        aux32[1] = (aux32[0] >> 4) & 0x03030303;
+    //        aux32[0] &= 0x03030303;
+    //        for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+    //        q8 += 8;
+    //        aux32[0] = val2 | (val2 << 14);
+    //        aux32[1] = (aux32[0] >> 4) & 0x03030303;
+    //        aux32[0] &= 0x03030303;
+    //        for (int k = 0; k < 8; ++k) sumi[k] += q8[k] * (aux8[k] - 1);
+    //        q8 += 8;
+    //    }
+    //}
+
+    //*s = d8[0] * (sumi[0] + sumi[4]) + d8[1] * (sumi[1] + sumi[5]) + d8[2] * (sumi[2] + sumi[6]) + d8[3] * (sumi[3] + sumi[7]);
 }

 void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
iqk_mul_mat.cpp
@@ -1342,44 +1342,31 @@ template <int nrc> struct Q8_K64 {
 struct DequantizerIQ1BN {
     const __m256i m1_8 = _mm256_set1_epi8(1);
-#ifdef HAVE_FANCY_SIMD
-    const __m128i shifthh = _mm_set_epi16(5, 6, 7, 8, 9, 10, 11, 12);
-#else
-    const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
-#endif
-    const __m128i maskhh = _mm_set1_epi16(4096);
-    const __m256i shuffles[4] = {
-        _mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
-        _mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
-        _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
-        _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
+    static __m128i load_shuffle(int i) {
+        static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+                                         3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+                                         6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+                                         9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+        return _mm_loadu_si128((const __m128i*)data + i);
+    }
+    const __m128i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+    const __m256i mult[4] = {
+        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
     };
-    const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
     const __m256i m3 = _mm256_set1_epi16(3);
-    const __m128i shuff_l = _mm_set_epi8(-128, 8, -128, 7, -128, 6, -128, 5, -128, 4, -128, 3, -128, 2, -128, 1);
-    const __m128i shuff_h = _mm_set_epi8(12, -128, 11, -128, 10, -128, 9, -128, 12, -128, 11, -128, 10, -128, 9, -128);
-    const __m128i shift_h = _mm_set_epi32(4, 4, 0, 0);
-    const __m128i mask_h = _mm_set1_epi16(0x0f00);
-    const __m128i shuff_hh = _mm_set_epi8(-128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0, -128, 0);
 #ifdef HAVE_FANCY_SIMD
     const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
 #endif

-    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) {
+    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
         auto data = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
-        auto aux1 = _mm_shuffle_epi8(data, shuff_l);
-        auto aux2 = _mm_and_si128(_mm_srlv_epi32(_mm_shuffle_epi8(data, shuff_h), shift_h), mask_h);
-#ifdef HAVE_FANCY_SIMD
-        auto aux3 = _mm_and_si128(_mm_sllv_epi16(_mm_shuffle_epi8(data, shuff_hh), shifthh), maskhh);
-#else
-        auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_shuffle_epi8(data, shuff_hh), mulhh), maskhh);
-#endif
-        auto all128 = _mm_or_si128(_mm_or_si128(aux1, aux2), aux3);
-        auto all = MM256_SET_M128I(all128, all128);
-        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
-        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
-        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
-        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
+        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[0])), mult[0]), m3);
+        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[1])), mult[1]), m3);
+        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[2])), mult[2]), m3);
+        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_cvtepu8_epi16(_mm_shuffle_epi8(data, shuff[3])), mult[3]), m3);
 #ifdef HAVE_FANCY_SIMD
         v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
         v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
@@ -1389,21 +1376,6 @@ struct DequantizerIQ1BN {
 #endif
     }

-    //IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
-    //    auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
-    //    uint32_t aux32; std::memcpy(&aux32, qh, 4);
-    //    auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
-    //    auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
-    //    auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
-    //    auto all = MM256_SET_M128I(all128, all128);
-    //    auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
-    //    auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
-    //    auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
-    //    auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
-    //    v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
-    //    v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
-    //}
 };

 template <int nrc_y>
@@ -1466,9 +1438,9 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
                 accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
 #else
             auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
-                    _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+                                         _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
             auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
-                    _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+                                         _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
             dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
             accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
 #endif
@@ -4376,73 +4348,29 @@ static const uint64_t kall_signs[257] = {
 struct DequantizerIQ1BN {
     const uint8x16_t m1 = vdupq_n_u8(1);

-    static inline uint8x16_t load_shuffle_l() {
-        static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
-        return vld1q_u8(data);
+    static inline uint8x16x4_t load_shuffles() {
+        static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+                                         3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+                                         6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+                                         9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+        return vld1q_u8_x4(data);
     }
-    static inline uint8x16_t load_shuffle_h() {
-        static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
-        return vld1q_u8(data);
+    static inline uint8x16x4_t load_mult() {
+        static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+        return vld1q_u8_x4(data);
     }
-    static inline uint8x16_t load_shuffle_hh() {
-        static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
-        return vld1q_u8(data);
-    }
-    static inline int16x8_t load_shift_hh() {
-        static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
-        return vld1q_s16(data);
-    }
-    static inline uint16x8_t load_mult() {
-        //static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
-        static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
-        return vld1q_u16(data);
-    }
+    //static inline uint8x16x4_t load_shuffles(uint16_t s0) {
+    //    uint8x16x4_t r;
+    //    auto step = vdupq_n_u8(4);
+    //    r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
+    //    r.val[1] = vaddq_u8(r.val[0], step);
+    //    r.val[2] = vaddq_u8(r.val[1], step);
+    //    r.val[3] = vaddq_u8(r.val[2], step);
+    //    return r;
+    //}

-    const uint8x16_t shuff_l = load_shuffle_l();
-    const uint8x16_t shuff_h = load_shuffle_h();
-    const int32x4_t  shift_h = {8, 8, 4, 4};
-    const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
-    const uint8x16_t shuff_hh = load_shuffle_hh();
-    const uint16x8_t mask_hh = vdupq_n_u16(4096);
-    const int16x8_t  shift_hh = load_shift_hh();
-    const uint16x8_t mult = load_mult();
-    const uint8x16_t step = vdupq_n_u8(2);
-    const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+    //const uint8x16x4_t shuff1 = load_shuffles(0x0100);
+    //const uint8x16x4_t shuff2 = load_shuffles(0x0302);
+    //const uint16x8_t mask = vdupq_n_u16(0x1fff);
+    //const uint16x8_t m3 = vdupq_n_u16(3);
+    const uint8x16x4_t shuff = load_shuffles();
+    const uint8x16x4_t mult = load_mult();

     IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
         auto data = vld1q_u8((const uint8_t *)x);
-        auto aux1 = vqtbl1q_u8(data, shuff_l);
-        auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
-        auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
-        auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
-        auto shuffle = shuff0;
+        //auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
+        //auto step = vdupq_n_u8(2);
         for (int k = 0; k < 4; ++k) {
-            auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
-            auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
-            v1 = vmulq_u16(v1, mult);
-            v2 = vmulq_u16(v2, mult);
-            v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
-            v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
-            v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
+            //auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
+            //auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
+            //v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
+            //v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
+            auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+            val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+            v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
         }
     }
 };