mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-02 12:39:54 +00:00
iq1bn: attempt without a lookup table
This commit is contained in:
@@ -118,17 +118,29 @@ uint16_t IQ1BNQuantizer::quantize_one_block_1bn(const IQ1BNData& iq1bn, const fl
|
||||
|
||||
void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
|
||||
|
||||
static const int k_nb[8] = {1, 3, 9, 27, 81, 243, 729, 2187};
|
||||
(void)imatrix;
|
||||
|
||||
const int nblock = n_per_row/QK_IQ1BN;
|
||||
|
||||
const auto& iq1bn = get_iq1bn_data();
|
||||
|
||||
for (int ib = 0; ib < nblock; ++ib) {
|
||||
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
|
||||
auto xb = src + QK_IQ1BN*ib;
|
||||
y[ib].extra = quantize_one_block_1bn(iq1bn, xb, L, y[ib].ql, y[ib].qh);
|
||||
auto xb = src + ib*QK_IQ1BN;
|
||||
for (int i = 0; i < QK_IQ1BN/8; ++i) {
|
||||
int idx = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
float v = xb[8*i + j];
|
||||
int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
|
||||
idx += k_nb[j]*q;
|
||||
}
|
||||
idx = (8192*idx + 6560)/6561;
|
||||
y[ib].ql[i] = idx & 255;
|
||||
y[ib].qh[i%4] |= ((idx >> 8) & 0xf) << 4*(i/4);
|
||||
y[ib].extra |= (idx >> 12) << i;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
|
||||
@@ -182,21 +194,18 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
|
||||
assert(k%QK_IQ1BN == 0);
|
||||
int nblock = k / QK_IQ1BN;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const int8_t * aux8 = (const int8_t *)aux32;
|
||||
static const int k_mult[8] = {17496, 5832, 1944, 648, 216, 72, 24, 8};
|
||||
|
||||
for (int i = 0; i < nblock; ++i) {
|
||||
uint8_t extra = x[i].extra;
|
||||
auto qh = x[i].qh;
|
||||
auto ql = x[i].ql;
|
||||
for (int k = 0; k < QK_IQ1BN/8; ++k) {
|
||||
uint16_t idx = ql[k] | ((qh[k/2] << (8 - 4*(k%2))) & 0x0f00);
|
||||
uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx];
|
||||
aux32[0] = val | (val << 14);
|
||||
aux32[1] = (aux32[0] >> 4) & 0x03030303;
|
||||
aux32[0] &= 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) y[j] = aux8[j] - 1;
|
||||
y += 8;
|
||||
extra >>= 1;
|
||||
uint16_t idx = ql[k] | ((qh[k%4] << (8 - 4*(k/4))) & 0x0f00) | ((extra << (12 - k)) & 4096);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
int v = (idx*k_mult[j] & 0xffff)*3 >> 16;
|
||||
*y++ = v - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1342,40 +1342,32 @@ template <int nrc> struct Q8_K64 {
|
||||
|
||||
struct DequantizerIQ1BN {
|
||||
const __m256i m1_8 = _mm256_set1_epi8(1);
|
||||
const __m256i shuff1 = _mm256_set_epi64x(0x0908090809080908, 0x0100010001000100, 0x0908090809080908, 0x0100010001000100);
|
||||
#if defined __AVX512F__ && defined __AVX512VL__
|
||||
const __m256i minus1 = _mm256_set1_epi64x(0xaaaa);
|
||||
const __m256i shifts = _mm256_set1_epi64x(0x0006000400020000);
|
||||
#else
|
||||
const __m256i shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
|
||||
const __m256i shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
|
||||
const __m256i mask1 = _mm256_set1_epi64x(0x8040201008040201);
|
||||
#endif
|
||||
const __m256i qmask = _mm256_set1_epi8(0x03);
|
||||
const __m128i mask1 = _mm_set1_epi8(0xf0);
|
||||
const __m128i mulhh = _mm_set_epi16(32, 64, 128, 256, 512, 1024, 2048, 4096);
|
||||
const __m128i maskhh = _mm_set1_epi16(4096);
|
||||
const __m256i shuffles[4] = {
|
||||
_mm256_set_epi64x(0x0302030203020302, 0x0302030203020302, 0x0100010001000100, 0x0100010001000100),
|
||||
_mm256_set_epi64x(0x0706070607060706, 0x0706070607060706, 0x0504050405040504, 0x0504050405040504),
|
||||
_mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0908090809080908),
|
||||
_mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0f0e0f0e0f0e0f0e, 0x0d0c0d0c0d0c0d0c, 0x0d0c0d0c0d0c0d0c),
|
||||
};
|
||||
const __m256i mult = _mm256_set_epi16(8, 24, 72, 216, 648, 1944, 5832, 17496, 8, 24, 72, 216, 648, 1944, 5832, 17496);
|
||||
const __m256i m3 = _mm256_set1_epi16(3);
|
||||
|
||||
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, __m256i& v1, __m256i& v2) {
|
||||
|
||||
auto aux1 = _mm256_set_epi64x(iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)],
|
||||
iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)]);
|
||||
auto aux2 = _mm256_set_epi64x(iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)],
|
||||
iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)]);
|
||||
#if defined __AVX512F__ && defined __AVX512VL__
|
||||
aux1 = _mm256_mask_sub_epi64(aux1, extra & 0xf, minus1, aux1);
|
||||
aux2 = _mm256_mask_sub_epi64(aux2, extra >> 4, minus1, aux2);
|
||||
v1 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srlv_epi16(_mm256_shuffle_epi8(aux1, shuff1), shifts), qmask), m1_8);
|
||||
v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srlv_epi16(_mm256_shuffle_epi8(aux2, shuff1), shifts), qmask), m1_8);
|
||||
#else
|
||||
aux1 = _mm256_or_si256(aux1, _mm256_slli_epi64(aux1, 14));
|
||||
aux2 = _mm256_or_si256(aux2, _mm256_slli_epi64(aux2, 14));
|
||||
aux1 = _mm256_or_si256(aux1, _mm256_slli_epi64(aux1, 28));
|
||||
aux2 = _mm256_or_si256(aux2, _mm256_slli_epi64(aux2, 28));
|
||||
v1 = _mm256_sub_epi8(_mm256_and_si256(aux1, qmask), m1_8);
|
||||
v2 = _mm256_sub_epi8(_mm256_and_si256(aux2, qmask), m1_8);
|
||||
auto all_signs = _mm256_set1_epi8(extra);
|
||||
all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
|
||||
v1 = _mm256_sign_epi8(v1, _mm256_shuffle_epi8(all_signs, shuff3));
|
||||
v2 = _mm256_sign_epi8(v2, _mm256_shuffle_epi8(all_signs, shuff4));
|
||||
#endif
|
||||
auto aux1 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)ql));
|
||||
uint32_t aux32; std::memcpy(&aux32, qh, 4);
|
||||
auto aux2 = _mm_cvtepu8_epi16(_mm_and_si128(_mm_set_epi32(aux32, aux32, aux32, aux32 << 4), mask1));
|
||||
auto aux3 = _mm_and_si128(_mm_mullo_epi16(_mm_set1_epi16(extra), mulhh), maskhh);
|
||||
auto all128 = _mm_or_si128(_mm_slli_epi16(aux2, 4), _mm_or_si128(aux1, aux3));
|
||||
auto all = MM256_SET_M128I(all128, all128);
|
||||
auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[0]), mult), m3);
|
||||
auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[1]), mult), m3);
|
||||
auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[2]), mult), m3);
|
||||
auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(all, shuffles[3]), mult), m3);
|
||||
v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
|
||||
v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user