mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-07 04:20:03 +00:00
iqk_mul_mat: add iq1_bn (bitnet)
We get 174 t/s for PP-512 and 49 t/s for TG-128 using 16 threads.
This commit is contained in:
@@ -1313,6 +1313,92 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc> struct Q8_K64 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
|
||||
Q8_K64(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_K64 *)info.src1_row(iy); }
|
||||
|
||||
inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
|
||||
inline float scale(int iy, int i) const { return y[iy][i].d; }
|
||||
|
||||
const block_q8_K64 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
Q8_K64<nrc_y> q8(info);
|
||||
__m256 accd[nrc_y];
|
||||
__m256i signs[2];
|
||||
|
||||
const auto m1_8 = _mm256_set1_epi8(1);
|
||||
const auto shuff1 = _mm256_set_epi64x(0x0808080808080808, 0x0000000000000000, 0x0808080808080808, 0x0000000000000000);
|
||||
const auto shuff2 = _mm256_add_epi8(shuff1, m1_8);
|
||||
const auto shuff3 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
|
||||
const auto shuff4 = _mm256_set_epi64x(0x0707070707070707, 0x0606060606060606, 0x0505050505050505, 0x0404040404040404);
|
||||
const auto mask1 = _mm256_set1_epi64x(0x8040201008040201);
|
||||
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
//auto step = bx / sizeof(block_iq1_bn);
|
||||
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
|
||||
typedef union { float f; uint32_t i; } scale_t;
|
||||
|
||||
scale_t scale;
|
||||
uint16_t u = x[0].extra & 0xff;
|
||||
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
auto all_signs = _mm256_set1_epi8(x[i].extra >> 8);
|
||||
all_signs = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(all_signs, mask1), mask1), m1_8);
|
||||
signs[0] = _mm256_shuffle_epi8(all_signs, shuff3);
|
||||
signs[1] = _mm256_shuffle_epi8(all_signs, shuff4);
|
||||
|
||||
auto ql = x[i].ql;
|
||||
auto qh = x[i].qh;
|
||||
auto aux1 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[3] | ((qh[1] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[2] | ((qh[1] << 8) & 0x0f00)],
|
||||
iq1bn_grid_xxx[ql[1] | ((qh[0] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[0] | ((qh[0] << 8) & 0x0f00)]);
|
||||
auto aux2 = _mm256_set_epi64x(iq1bn_grid_xxx[ql[7] | ((qh[3] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[6] | ((qh[3] << 8) & 0x0f00)],
|
||||
iq1bn_grid_xxx[ql[5] | ((qh[2] << 4) & 0x0f00)], iq1bn_grid_xxx[ql[4] | ((qh[2] << 8) & 0x0f00)]);
|
||||
|
||||
auto v1_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff1), mask1), mask1);
|
||||
auto v1_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux1, shuff2), mask1), mask1);
|
||||
auto v2_p = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1);
|
||||
auto v2_m = _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto q8_1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), signs[0]);
|
||||
auto q8_2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), signs[1]);
|
||||
auto dot1 = _mm256_sub_epi8(_mm256_sign_epi8(q8_1, v1_m), _mm256_sign_epi8(q8_1, v1_p));
|
||||
auto dot2 = _mm256_sub_epi8(_mm256_sign_epi8(q8_2, v2_m), _mm256_sign_epi8(q8_2, v2_p));
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1);
|
||||
dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot2);
|
||||
#else
|
||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot1));
|
||||
dot2 = _mm256_madd_epi16(m1_16, _mm256_maddubs_epi16(m1_8, dot2));
|
||||
#endif
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(dot1, dot2)), accd[iy]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, scale.f * hsum_float_8(accd[iy]));
|
||||
}
|
||||
|
||||
//x += step;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -2504,6 +2590,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2XXS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
|
||||
mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
|
||||
mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
|
||||
mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
|
||||
mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
|
||||
mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
|
||||
mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
|
||||
mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
assert (ne00 % QK4_0 == 0);
|
||||
MulMat::set_functions<Q4_0_Unpacker>(mm);
|
||||
|
||||
Reference in New Issue
Block a user