WIP (Zen4)

This commit is contained in:
Iwan Kawrakow
2025-04-21 11:27:33 +03:00
parent 26eb64c4f9
commit a7cd27f7e0

View File

@@ -8580,6 +8580,39 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
}
}
template <typename Unpacker, int nrc_y, int nrc_x>
void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
static_assert(8%nrc_y == 0);
Q8<nrc_y, block_q8_2> q8(info);
int nb = n/Unpacker::block_size();
Unpacker unp(vx, bx);
typename Unpacker::Sum4T sum4;
ScaleHelperQ8_2 scales;
__m256 result[8];
float val[8];
if (nb%4 == 0) {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
for (int ix = 0; ix < 8/nrc_y; ++ix) {
unp.set_row(ix0 + ix);
AccumType1<nrc_y, true> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
}
} else {
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
for (int ix = 0; ix < 8/nrc_y; ++ix) {
unp.set_row(ix0 + ix);
AccumType1<nrc_y, false> accum;
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
}
_mm256_storeu_ps(val, hsum_float_8x8(result));
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
}
}
}
struct Dequantizer4bit {
const __m256i m4 = _mm256_set1_epi8(0xf);
inline __m256i dequant(const uint8_t * qs) const {
@@ -16710,6 +16743,9 @@ struct FlashQKfp32 {
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
#ifdef HAVE_FANCY_SIMD
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);