mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
iq4_xxs_r4: WIP
This commit is contained in:
@@ -3218,6 +3218,11 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
int nbl = n / QK_K;
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
auto smask = _mm256_set1_epi64x(0x8040201008040201);
|
||||
auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
|
||||
auto m4 = _mm256_set1_epi8(4);
|
||||
#endif
|
||||
__m256 acc[nrc_y] = {};
|
||||
__m256i isum[nrc_y] = {};
|
||||
__m256i qx[4];
|
||||
@@ -3236,14 +3241,21 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
|
||||
qx[3] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+31]], iq3xxs_grid[iq3[ibl].qs[32*ib+30]], iq3xxs_grid[iq3[ibl].qs[32*ib+29]], iq3xxs_grid[iq3[ibl].qs[32*ib+28]],
|
||||
iq3xxs_grid[iq3[ibl].qs[32*ib+27]], iq3xxs_grid[iq3[ibl].qs[32*ib+26]], iq3xxs_grid[iq3[ibl].qs[32*ib+25]], iq3xxs_grid[iq3[ibl].qs[32*ib+24]]);
|
||||
auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib);
|
||||
auto signs = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to shutup compiler warning.
|
||||
signs = _mm_xor_si128(signs, _mm_srli_epi16(signs, 1));
|
||||
auto mask = (const __mmask32 *)&signs;
|
||||
auto scales = _mm_and_si128(sas, _mm_set1_epi8(1));
|
||||
auto t1 = _mm_or_si128(_mm_and_si128(scales, _mm_set1_epi32(0x00000001)), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00000100)), 7));
|
||||
auto t2 = _mm_or_si128(_mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00010000)), 14), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x01000000)), 21));
|
||||
scales = _mm_or_si128(_mm_slli_epi32(_mm_or_si128(t1, t2), 1), _mm_set1_epi32(1));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402));
|
||||
#else
|
||||
scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402));
|
||||
scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1));
|
||||
//auto t1 = _mm_or_si128(_mm_and_si128(scales, _mm_set1_epi32(0x00000001)), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00000100)), 7));
|
||||
//auto t2 = _mm_or_si128(_mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00010000)), 14), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x01000000)), 21));
|
||||
//scales = _mm_or_si128(_mm_slli_epi32(_mm_or_si128(t1, t2), 1), _mm_set1_epi32(1));
|
||||
#endif
|
||||
auto scales32 = MM256_SET_M128I(scales, scales);
|
||||
auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to shutup compiler warning.
|
||||
signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1));
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto mask = (const __mmask32 *)&signs128;
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
|
||||
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y));
|
||||
@@ -3255,6 +3267,28 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
|
||||
auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
|
||||
}
|
||||
#else
|
||||
auto signs = MM256_SET_M128I(signs128, signs128);
|
||||
auto shuffle = sign_shuffle;
|
||||
auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
|
||||
shuffle = _mm256_add_epi8(shuffle, m4);
|
||||
auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
|
||||
shuffle = _mm256_add_epi8(shuffle, m4);
|
||||
auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
|
||||
shuffle = _mm256_add_epi8(shuffle, m4);
|
||||
auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
|
||||
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_sign_epi8(y, s1));
|
||||
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_sign_epi8(y, s2));
|
||||
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_sign_epi8(y, s3));
|
||||
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_sign_epi8(y, s4));
|
||||
auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1
|
||||
auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3
|
||||
auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3
|
||||
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
|
||||
|
||||
Reference in New Issue
Block a user