mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
WIP (Zen4)
This commit is contained in:
@@ -8580,6 +8580,39 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Unpacker, int nrc_y, int nrc_x>
|
||||
void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) {
|
||||
static_assert(8%nrc_y == 0);
|
||||
Q8<nrc_y, block_q8_2> q8(info);
|
||||
int nb = n/Unpacker::block_size();
|
||||
Unpacker unp(vx, bx);
|
||||
typename Unpacker::Sum4T sum4;
|
||||
ScaleHelperQ8_2 scales;
|
||||
__m256 result[8];
|
||||
float val[8];
|
||||
if (nb%4 == 0) {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
for (int ix = 0; ix < 8/nrc_y; ++ix) {
|
||||
unp.set_row(ix0 + ix);
|
||||
AccumType1<nrc_y, true> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
}
|
||||
} else {
|
||||
for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) {
|
||||
for (int ix = 0; ix < 8/nrc_y; ++ix) {
|
||||
unp.set_row(ix0 + ix);
|
||||
AccumType1<nrc_y, false> accum;
|
||||
accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix);
|
||||
}
|
||||
_mm256_storeu_ps(val, hsum_float_8x8(result));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Dequantizer4bit {
|
||||
const __m256i m4 = _mm256_set1_epi8(0xf);
|
||||
inline __m256i dequant(const uint8_t * qs) const {
|
||||
@@ -16710,6 +16743,9 @@ struct FlashQKfp32 {
|
||||
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
|
||||
#else
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 1, k_step>, 1);
|
||||
if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 2, k_step>, 2);
|
||||
if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx<Q8_0_1_Unpacker, 4, k_step>, 4);
|
||||
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
|
||||
#else
|
||||
if (nq == 1) return std::make_pair(mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1, k_step>, 1);
|
||||
|
||||
Reference in New Issue
Block a user