This commit is contained in:
Iwan Kawrakow
2024-12-09 16:26:24 +02:00
parent 9b475867f1
commit a515388119

View File

@@ -3085,7 +3085,7 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd, hm;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
@@ -3093,7 +3093,7 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
@@ -3101,26 +3101,24 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
hm.vec = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
//if constexpr (nrc_y > 2) {
m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), m4);
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[0])))));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[1])))));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[2])))));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[3])))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
//} else {
// m4 = _mm256_mul_ps(_mm256_set1_ps(-0.5f), m4);
//}
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
auto shuffle = _mm256_set1_epi64x(0x0000000400000000);
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle)))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
for (int ib = 0; ib < QK_K/32; ++ib) {
auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]))));
//auto scales_m = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib]))));
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_and_si256(bits1, mf);
@@ -3148,10 +3146,6 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D
float d8 = q8.scale(iy, ibl);
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
//if constexpr (nrc_y <= 2) {
// float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
// acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
//}
}
}
}