mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 03:20:00 +00:00
bitnet (scale in a separate tensor): CPU tweaks
A somewhat nicer iq2_bn implementation on AVX2.
This commit is contained in:
@@ -1462,43 +1462,59 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
}
|
||||
}
|
||||
|
||||
// Dequantizer for the iq2_bn (BitNet-style ternary, 2 bits per weight) block format.
// Each block_iq2_bn packs four 2-bit values per byte in x[i].qs; this helper unpacks
// them into __m256i lanes of signed bytes for the AVX2 matrix-multiply kernels.
// NOTE(review): val[1] of make2() comes out scaled by 16 (see mf_8/mask3 below);
// presumably the caller compensates for that factor downstream — confirm against
// the mul_mat_iq2bn_q8_K64 accumulation/scale path.
struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    // Unpack two consecutive blocks (x[2*i] and x[2*i+1], loaded as one 256-bit
    // register) into four __m256i of signed bytes, written to val[0..3].
    inline void prepare4(int i, __m256i * val) const {
        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
        // Shift by 2 so the next 2-bit field of every byte lines up with the low bits.
        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
        // 0x20 selects (low128 of q2bits_1, low128 of q2bits_2);
        // 0x31 selects (high128 of q2bits_1, high128 of q2bits_2).
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
    }
    // From one register of packed 2-bit fields produce two signed-byte registers:
    //   val[0] = (q & 0x03) - 1         -> ternary values in {-1, 0, 1} (stored codes 0..2)
    //   val[1] = (q & 0x30) - 16        -> same ternary values, but scaled by 16
    inline void make2(__m256i q2_1, __m256i * val) const {
        val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
        val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
    }
    // Unpack a single (possibly trailing, odd-count) block x[i] into val[0..1].
    inline void prepare2(int i, __m256i * val) const {
        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
        // Place the raw bits in the low 128 lane and the >>2 bits in the high lane,
        // mirroring the layout prepare4() builds with the permutes.
        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
    }
    const __m256i m1_8  = _mm256_set1_epi8(1);     // bias subtracted from the low 2-bit field
    const __m256i mf_8  = _mm256_set1_epi8(16);    // bias for the x16-scaled high field
    const __m256i mask2 = _mm256_set1_epi8(0x03);  // low 2-bit field of each byte
    const __m256i mask3 = _mm256_set1_epi8(0x30);  // bits 4..5 of each byte
};
|
||||
|
||||
template <int nrc_y>
|
||||
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
Q8_K64<nrc_y> q8(info);
|
||||
DequantizeIQ2BN deq(vx, bx);
|
||||
__m256i accd[nrc_y];
|
||||
__m256i val[4];
|
||||
|
||||
const auto m1_8 = _mm256_set1_epi8(1);
|
||||
const auto mask2 = _mm256_set1_epi8(3);
|
||||
#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
|
||||
deq.new_row(ix);
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
__m256i acc[2] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
|
||||
auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
|
||||
auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
|
||||
auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
|
||||
auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
|
||||
auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
|
||||
auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
|
||||
auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
|
||||
deq.prepare4(i, val);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
|
||||
m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2));
|
||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
|
||||
m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4));
|
||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
|
||||
deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
|
||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
|
||||
deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), v1)),
|
||||
_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), v2)));
|
||||
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), v3)),
|
||||
_mm256_maddubs_epi16(m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), v4)));
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
|
||||
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
|
||||
_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
|
||||
acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
|
||||
acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
|
||||
#endif
|
||||
@@ -1510,26 +1526,19 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
|
||||
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[2*i+0].qs);
|
||||
auto q2bits_2 = _mm_loadu_si128((const __m128i *)x[2*i+1].qs);
|
||||
auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1);
|
||||
auto q2_2 = MM256_SET_M128I(_mm_srli_epi16(q2bits_2, 2), q2bits_2);
|
||||
auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
|
||||
auto v3 = _mm256_sub_epi8(_mm256_and_si256(q2_2, mask2), m1_8);
|
||||
auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
|
||||
auto v4 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_2, 4), mask2), m1_8);
|
||||
deq.prepare4(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
|
||||
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), v3);
|
||||
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), v4);
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
|
||||
auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
|
||||
auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
|
||||
accd[iy], m1_8, dot1), m1_8, dot2), m1_8, dot3), m1_8, dot4);
|
||||
accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)),
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot3), _mm256_maddubs_epi16(m1_8, dot4))));
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
|
||||
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
@@ -1537,17 +1546,14 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const
|
||||
}
|
||||
int i = 2*(nb/2);
|
||||
if (i < nb) {
|
||||
auto q2bits = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
auto q2_1 = MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits);
|
||||
auto v1 = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
|
||||
auto v2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2), m1_8);
|
||||
deq.prepare2(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), v1);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), v2);
|
||||
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
|
||||
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
|
||||
#if defined __AVX512VNNI__ && defined __AVX512VL__
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], m1_8, dot1), m1_8, dot2);
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
|
||||
#else
|
||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
|
||||
dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
|
||||
accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user