Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-25 15:44:10 +00:00.
iq5_k_r5: WIP
This commit is contained in:
@@ -3960,7 +3960,8 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
|
||||
#endif
|
||||
|
||||
template <int nrc_y>
|
||||
IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
|
||||
//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
|
||||
inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
|
||||
__m256i * isum, int16_t min) {
|
||||
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
|
||||
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
|
||||
@@ -4008,6 +4009,46 @@ IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate the block-minimum correction term into the int32 accumulators
// isum[0..nrc_y), for quant types where each of the 16 blocks in a super-block
// uses a minimum of either `min` or `min + delta`, selected by a per-block
// flag bit in `extra`.
//
// ibl       : super-block index into the q8 activations
// i8scales1 : int8 block scales (blocks 0..7); comments below suggest 4
//             interleaved rows per scale group — TODO confirm against caller
// i8scales2 : int8 block scales (blocks 8..15)
// q8        : quantized activations (block_q8_K), provides per-block sums
// shuff     : byte shuffle applied after widening the scales to int16
// extra     : per-block flag bits; a set bit bumps that block's min by delta
// isum      : nrc_y int32 accumulators, updated in place
// min,delta : base minimum and the optional per-block increment
template <int nrc_y>
inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
        __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
    // Each 64-bit lane tests one flag bit per byte: lane 0 -> bit 0, lane 1 -> bit 1,
    // lane 2 -> bit 2, lane 3 -> bit 3 (bits 4..7 are reached via the >>4 below).
    auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
    auto vdelta = _mm256_set1_epi8(delta);
    auto vmin = _mm256_set1_epi8(min);
    // Effective per-block minimum: min + (flag ? delta : 0). cmpeq yields
    // all-ones bytes where the tested bit is set, which then gates vdelta.
    auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
    // Same for the upper flag nibble (bits 4..7 of each extra byte).
    auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
    // Widen the int8 scales to int16 and rearrange with `shuff`.
    auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
    auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
    auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
    auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
    // Widen the effective minima to int16 with the identical layout.
    auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
    auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
    auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
    auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
    // Per-block int16 products scale*min, with 128-bit halves regrouped so the
    // block order matches the layout of the q8 block sums used below.
    auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
            MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
    auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
            MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
    auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
            MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
    auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
            MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
    // Fold sum_b(scale_b * min_b * bsum_b) into each column's accumulator.
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto bsums = q8.load_bsums(iy, ibl);
#ifdef HAVE_FANCY_SIMD
        // AVX512-VNNI: fused int16 pairwise multiply-accumulate into int32.
        // shuffle_epi32 broadcasts one 32-bit group of bsums per 128-bit lane.
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
#else
        // AVX2 fallback: madd (int16 pair dot product) followed by add.
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
#endif
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
@@ -4285,8 +4326,8 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
values[0] = MM256_SET_M128I(val1, val1);
|
||||
values[1] = MM256_SET_M128I(val2, val2);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
values[0] = _mm256_add_epi8(values[0], _mm256_set1_epi8(127));
|
||||
values[1] = _mm256_add_epi8(values[1], _mm256_set1_epi8(127));
|
||||
values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
|
||||
values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
|
||||
#endif
|
||||
}
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
@@ -4295,6 +4336,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
#else
|
||||
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
|
||||
#endif
|
||||
auto m128 = _mm256_set1_epi8(-128);
|
||||
int nbl = n / QK_K;
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
@@ -4316,7 +4358,11 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
|
||||
__m256i isum[nrc_y] = {};
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -127);
|
||||
if constexpr (nrc_y == 1) {
|
||||
iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
|
||||
} else {
|
||||
iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
|
||||
}
|
||||
#endif
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
@@ -4328,8 +4374,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
|
||||
auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
|
||||
auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
|
||||
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
|
||||
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
|
||||
qx[0] = _mm256_and_si256(lbits1, m4);
|
||||
qx[1] = _mm256_and_si256(lbits2, m4);
|
||||
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
|
||||
@@ -4352,31 +4396,41 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
// qx[2] = _mm256_add_epi8(qx[2], shift);
|
||||
// qx[3] = _mm256_add_epi8(qx[3], shift);
|
||||
//#else
|
||||
auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), _mm256_set1_epi8(-128));
|
||||
auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128);
|
||||
auto q5vl = _mm256_or_si256(qx[0], qh);
|
||||
auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
|
||||
auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128));
|
||||
qx[0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
|
||||
qx[0] = _mm256_add_epi8(qx[0], shift);
|
||||
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), _mm256_set1_epi8(-128));
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), m128);
|
||||
q5vl = _mm256_or_si256(qx[1], qh);
|
||||
q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
|
||||
q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, m128));
|
||||
qx[1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
|
||||
qx[1] = _mm256_add_epi8(qx[1], shift);
|
||||
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), _mm256_set1_epi8(-128));
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), m128);
|
||||
q5vl = _mm256_or_si256(qx[2], qh);
|
||||
q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
|
||||
q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, m128));
|
||||
qx[2] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
|
||||
qx[2] = _mm256_add_epi8(qx[2], shift);
|
||||
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), _mm256_set1_epi8(-128));
|
||||
qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), m128);
|
||||
q5vl = _mm256_or_si256(qx[3], qh);
|
||||
q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
|
||||
q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128));
|
||||
qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
|
||||
qx[3] = _mm256_add_epi8(qx[3], shift);
|
||||
if constexpr (nrc_y == 1) {
|
||||
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
|
||||
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
|
||||
qx[0] = _mm256_add_epi8(qx[0], shift);
|
||||
qx[1] = _mm256_add_epi8(qx[1], shift);
|
||||
qx[2] = _mm256_add_epi8(qx[2], shift);
|
||||
qx[3] = _mm256_add_epi8(qx[3], shift);
|
||||
}
|
||||
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
|
||||
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
|
||||
qx[0] = _mm256_add_epi8(qx[0], shift);
|
||||
qx[1] = _mm256_add_epi8(qx[1], shift);
|
||||
qx[2] = _mm256_add_epi8(qx[2], shift);
|
||||
qx[3] = _mm256_add_epi8(qx[3], shift);
|
||||
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
|
||||
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
|
||||
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
|
||||
|
||||
Reference in New Issue
Block a user