commit 27e987d959
parent 5eac4edc90
Author: Iwan Kawrakow
Date:   2024-12-18 10:26:47 +02:00

    iq5_k_r5: WIP


@@ -3960,7 +3960,8 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
 #endif
 template <int nrc_y>
-IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
+//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
+inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
                                __m256i * isum, int16_t min) {
     auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
     auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
@@ -4008,6 +4009,46 @@ IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8
     }
 }
+template <int nrc_y>
+inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8<nrc_y, block_q8_K>& q8, __m256i shuff,
+                                __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
+    auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
+    auto vdelta = _mm256_set1_epi8(delta);
+    auto vmin = _mm256_set1_epi8(min);
+    auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
+    auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
+    auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+    auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+    auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+    auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+    auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+    auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+    auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+    auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+    auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
+    auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
+    auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
+    auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto bsums = q8.load_bsums(iy, ibl);
+#ifdef HAVE_FANCY_SIMD
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+    }
+}
 template <int nrc_y>
 static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
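
For reference, here is a scalar sketch of the quantity the new iq2345_k_accum_mins accumulates, ignoring the 4-row interleaving of the _r4 layout; the function name, the flat extra-bit indexing, and the array shapes are illustrative assumptions, not the in-memory layout of the actual code:

```cpp
#include <cstdint>

// Per super-block correction: each of the 16 blocks contributes
// min_eff * scale * bsum, where min_eff is `min` plus `delta` when that
// block's extra bit is set, and bsum is the precomputed sum of the q8
// values of the block (the block_q8_K bsums).
int32_t accum_mins_ref(const int8_t scales[16], uint16_t extra,
                       const int16_t bsums[16], int8_t min, int8_t delta) {
    int32_t sum = 0;
    for (int ib = 0; ib < 16; ++ib) {
        int32_t min_eff = min + (((extra >> ib) & 1) ? delta : 0);
        sum += min_eff * scales[ib] * bsums[ib];
    }
    return sum;
}
```

The SIMD version derives min_eff for all blocks at once via the cmpeq/and pair against the 0x01/0x02/0x04/0x08 byte masks, multiplies mins by scales with _mm256_mullo_epi16, and then accumulates against the broadcast bsums with _mm256_dpwssd_epi32 (or madd + add without AVX512-VNNI).
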
@@ -4285,8 +4326,8 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
         values[0] = MM256_SET_M128I(val1, val1);
         values[1] = MM256_SET_M128I(val2, val2);
 #ifdef HAVE_FANCY_SIMD
-        values[0] = _mm256_add_epi8(values[0], _mm256_set1_epi8(127));
-        values[1] = _mm256_add_epi8(values[1], _mm256_set1_epi8(127));
+        values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
+        values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
#endif
     }
 #ifdef HAVE_FANCY_SIMD
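
The bias on the AVX512-VNNI path changes here from +127 to exactly +128, mapping the signed 5-bit LUT values onto the unsigned range the unsigned-by-signed dot products expect; correspondingly, the min argument below changes from -127 to -128 so the bias is subtracted back out through the bsums. A minimal sketch of the identity, assuming only <immintrin.h>:

```cpp
#include <immintrin.h>

// Adding 128 to a signed byte maps [-128, 127] onto [0, 255]. Modulo 256,
// adding 128 and subtracting 128 coincide, so _mm256_sub_epi8 with a
// set1_epi8(-128) constant (or an XOR with 0x80) is the natural spelling,
// since +128 does not fit in an int8_t literal.
static inline __m256i bias128(__m256i v) {
    return _mm256_sub_epi8(v, _mm256_set1_epi8(-128));
}
```
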
@@ -4295,6 +4336,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 #else
     auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
 #endif
+    auto m128 = _mm256_set1_epi8(-128);
     int nbl = n / QK_K;
     __m256 acc[nrc_y] = {};
     __m256i qx[4];
@@ -4316,7 +4358,11 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
         _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
         __m256i isum[nrc_y] = {};
 #ifdef HAVE_FANCY_SIMD
-        iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -127);
+        if constexpr (nrc_y == 1) {
+            iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
+        } else {
+            iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
+        }
 #endif
         for (int ib = 0; ib < QK_K/32; ++ib) {
 #ifdef HAVE_FANCY_SIMD
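
The dispatch above handles the extra-bit shift in two ways: with a single right-hand-side column it is added directly to the quants in the inner loop, while for several columns it is folded into the mins pass (delta = 2, matching the shift the removed inner-loop code applied) using the identity sum_j (q_j + s)*y_j = sum_j q_j*y_j + s*sum_j y_j, where sum_j y_j is the bsum already stored in block_q8_K. A self-contained check of that identity:

```cpp
#include <cstdint>
#include <cstdio>

// Adding a per-block shift s to every quant is equivalent to adding
// s * bsum to the block's dot product, where bsum is the sum of the
// q8 values of the block.
int main() {
    int8_t q[32], y[32];
    for (int j = 0; j < 32; ++j) { q[j] = int8_t(j - 16); y[j] = int8_t(3*j - 40); }
    const int s = 2;  // the delta handed to iq2345_k_accum_mins above
    int direct = 0, dot = 0, bsum = 0;
    for (int j = 0; j < 32; ++j) {
        direct += (q[j] + s) * y[j];
        dot    += q[j] * y[j];
        bsum   += y[j];
    }
    std::printf("%d == %d\n", direct, dot + s*bsum);  // both sides agree
    return 0;
}
```

Folding pays off for nrc_y > 1 because the mins correction is computed once per super-block, while the per-quant shift adds would otherwise sit in the hot loop.
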
@@ -4328,8 +4374,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
             auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
             auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
             auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
-            auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
-            shift = _mm256_shuffle_epi8(shift, shift_shuffle);
             qx[0] = _mm256_and_si256(lbits1, m4);
             qx[1] = _mm256_and_si256(lbits2, m4);
             qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
@@ -4352,31 +4396,41 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 // qx[2] = _mm256_add_epi8(qx[2], shift);
 // qx[3] = _mm256_add_epi8(qx[3], shift);
 //#else
-            auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), _mm256_set1_epi8(-128));
+            auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128);
             auto q5vl = _mm256_or_si256(qx[0], qh);
-            auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128));
             qx[0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[0] = _mm256_add_epi8(qx[0], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), m128);
             q5vl = _mm256_or_si256(qx[1], qh);
-            q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, m128));
             qx[1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[1] = _mm256_add_epi8(qx[1], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), m128);
             q5vl = _mm256_or_si256(qx[2], qh);
-            q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, m128));
             qx[2] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[2] = _mm256_add_epi8(qx[2], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), m128);
             q5vl = _mm256_or_si256(qx[3], qh);
-            q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128));
             qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[3] = _mm256_add_epi8(qx[3], shift);
+            if constexpr (nrc_y == 1) {
+                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+                qx[0] = _mm256_add_epi8(qx[0], shift);
+                qx[1] = _mm256_add_epi8(qx[1], shift);
+                qx[2] = _mm256_add_epi8(qx[2], shift);
+                qx[3] = _mm256_add_epi8(qx[3], shift);
+            }
 #ifndef HAVE_FANCY_SIMD
+            auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+            shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+            qx[0] = _mm256_add_epi8(qx[0], shift);
+            qx[1] = _mm256_add_epi8(qx[1], shift);
+            qx[2] = _mm256_add_epi8(qx[2], shift);
+            qx[3] = _mm256_add_epi8(qx[3], shift);
             auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
             auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
             auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
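
The q5vl/q5vh pairs in this hunk implement a 32-entry lookup out of two 16-entry pshufb tables: vpshufb returns zero for any index byte whose high bit is set, so moving the fifth quant bit into bit 7 masks out exactly one of the two lookups, and the OR keeps the surviving result. A standalone sketch of the pattern (function and parameter names are mine):

```cpp
#include <immintrin.h>

// tab0 holds LUT entries 0..15, tab1 entries 16..31 (repeated per 128-bit
// lane, since vpshufb indexes within each lane). lo4 carries the low 4
// index bits per byte; hi is 0x80 where the 5th index bit is set, 0 elsewhere.
static inline __m256i lookup32(__m256i lo4, __m256i hi, __m256i tab0, __m256i tab1) {
    __m256i m128 = _mm256_set1_epi8(-128);
    __m256i i_lo = _mm256_or_si256(lo4, hi);                          // bit5 = 1 -> tab0 lane zeroed
    __m256i i_hi = _mm256_or_si256(lo4, _mm256_xor_si256(hi, m128));  // bit5 = 0 -> tab1 lane zeroed
    return _mm256_or_si256(_mm256_shuffle_epi8(tab0, i_lo),
                           _mm256_shuffle_epi8(tab1, i_hi));
}
```

The trailing _mm256_sign_epi8(qx[i], qx[i]) lines compute |qx[i]| on the non-VNNI path, presumably so the sign can be transferred to the q8 operand of _mm256_maddubs_epi16, whose first argument must be unsigned.
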