commit fae18dd0bc
parent b498633203
Author: Iwan Kawrakow
Date:   2025-04-18 17:00:06 +03:00

@@ -8492,6 +8492,184 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
}
}
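// Single y row: each column of Q8_0 x is dotted with a Q8_0 y row that has
// been repacked four blocks at a time (block_q8_0_x4), so the four fp16 y
// scales of a group can be loaded and converted with one instruction.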
void mul_mat_q8_0_q8_x4_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
Q8<1, block_q8_0> q8(info);
int nblock = n/QK8_0;
auto cx = (const char *)vx;
auto y4 = (const block_q8_0_x4 *)q8.y[0];
ggml_half d4[4];
__m256 dot[4];
for (int ix = 0; ix < nrc_x; ++ix) {
auto q8x = (const block_q8_0 *)cx;
auto acc = _mm256_setzero_ps();
for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4[ib4].d));
auto d4y = _mm256_set_m128(d4y_128, d4y_128);
for (int k = 0; k < 4; ++k) {
d4[k] = q8x[k].d;
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
// _mm256_maddubs_epi16 multiplies unsigned by signed bytes, so feed it |x|
// and y with the sign of x transferred onto it: |x| * sign(x)*y == x*y
auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4[ib4].qs + k), qx));
dot[k] = _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_set1_epi16(1), p));
}
auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
auto d4x = _mm256_set_m128(d4x_128, d4x_128);
auto d4xy = _mm256_mul_ps(d4x, d4y);
// 0x00/0x55/0xaa/0xff broadcast scale k of d4xy within each 128-bit lane
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x00), dot[0], acc);
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x55), dot[1], acc);
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xaa), dot[2], acc);
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xff), dot[3], acc);
q8x += 4;
}
if (int ib0 = 4*(nblock/4); ib0 < nblock) {
auto y = q8.y[0] + ib0;
for (int k = 0; k < nblock - ib0; ++k) {
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y[k].qs), qx));
p = _mm256_madd_epi16(_mm256_set1_epi16(1), p);
auto dxy = GGML_FP16_TO_FP32(q8x[k].d)*GGML_FP16_TO_FP32(y[k].d);
acc = _mm256_fmadd_ps(_mm256_set1_ps(dxy), _mm256_cvtepi32_ps(p), acc);
}
}
info.store(ix, 0, hsum_float_8(acc));
cx += bx;
}
}
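Both kernels lean on the identity x*y == |x| * (y with the sign of x applied), which lets the unsigned-by-signed _mm256_maddubs_epi16 compute a signed int8 dot product. A standalone sketch that checks the trick (my illustration, not part of this commit; build with -mavx2):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    int8_t x[32], y[32];
    for (int i = 0; i < 32; ++i) { x[i] = int8_t(3*i - 48); y[i] = int8_t(90 - 5*i); }
    auto qx = _mm256_loadu_si256((const __m256i *)x);
    auto qy = _mm256_loadu_si256((const __m256i *)y);
    // unsigned |x| times sign-adjusted y, adjacent pairs summed to int16 ...
    auto p  = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
    // ... then int16 pairs summed to int32
    auto s  = _mm256_madd_epi16(_mm256_set1_epi16(1), p);
    int32_t lanes[8]; _mm256_storeu_si256((__m256i *)lanes, s);
    int simd = 0; for (int i = 0; i < 8; ++i) simd += lanes[i];
    int ref  = 0; for (int i = 0; i < 32; ++i) ref += x[i]*y[i];
    printf("simd = %d  ref = %d\n", simd, ref); // the two sums agree
}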
void mul_mat_q8_0_q8_x4_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
Q8<2, block_q8_0> q8(info);
int nblock = n/QK8_0;
auto cx = (const char *)vx;
auto y4l = (const block_q8_0_x4 *)q8.y[0];
auto y4h = (const block_q8_0_x4 *)q8.y[1];
ggml_half d4[4];
__m256i dot[4];
__m256 acc[4] = {};
for (int ix = 0; ix < nrc_x; ++ix) {
auto q8x = (const block_q8_0 *)cx;
for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
for (int k = 0; k < 4; ++k) {
d4[k] = q8x[k].d;
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
auto ux = _mm256_sign_epi8(qx, qx);
auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k), qx));
auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k), qx));
pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
// interleave so each 128-bit lane holds [low, low, high, high] partial sums
dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
}
auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
auto d4x = _mm256_set_m128(d4x_128, d4x_128);
auto d4xyl = _mm256_mul_ps(d4x, d4yl);
auto d4xyh = _mm256_mul_ps(d4x, d4yh);
acc[0] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc[0]);
acc[1] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc[1]);
acc[2] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc[2]);
acc[3] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc[3]);
q8x += 4;
}
if (int ib0 = 4*(nblock/4); ib0 < nblock) {
auto yl = q8.y[0] + ib0;
auto yh = q8.y[1] + ib0;
for (int k = 0; k < nblock - ib0; ++k) {
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
auto ux = _mm256_sign_epi8(qx, qx);
auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
auto d = GGML_FP16_TO_FP32(q8x[k].d);
auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
acc[k] = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc[k]);
}
}
// acc[0] ends up holding [low, low, high, high] per lane; fold the lanes
// and split the horizontal sums into the two row results
acc[0] = _mm256_add_ps(acc[0], acc[1]);
acc[2] = _mm256_add_ps(acc[2], acc[3]);
acc[0] = _mm256_add_ps(acc[0], acc[2]);
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[0]), _mm256_extractf128_ps(acc[0], 1));
float daux[4]; _mm_storeu_ps(daux, sum);
info.store(ix, 0, daux[0] + daux[1]); // row q8.y[0]
info.store(ix, 1, daux[2] + daux[3]); // row q8.y[1]
acc[0] = acc[1] = acc[2] = acc[3] = _mm256_setzero_ps();
cx += bx;
}
}
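Per stored result, mul_mat_q8_0_q8_x4_2 computes the same thing as the single-row kernel, once against each of the two y rows; the [low, low, high, high] lane layout only exists to keep both rows in flight in one set of accumulators. A minimal scalar sketch of one (x column, y row) pair, assuming the standard ggml block_q8_0 layout (QK8_0 int8 quants qs plus an fp16 scale d):

static float dot_q8_0_row(const block_q8_0 * x, const block_q8_0 * y, int nblock) {
    float sum = 0;
    for (int ib = 0; ib < nblock; ++ib) {
        int sumi = 0;
        for (int i = 0; i < QK8_0; ++i) sumi += x[ib].qs[i] * y[ib].qs[i]; // int8 dot
        sum += GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) * sumi; // scales
    }
    return sum;
}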
//void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// Q8<2, block_q8_0> q8(info);
// int nblock = n/QK8_0;
// auto cx = (const char *)vx;
// auto y4l = (const block_q8_0_x4 *)q8.y[0];
// auto y4h = (const block_q8_0_x4 *)q8.y[1];
// ggml_half d4[8]; // interleaved x scales: [xl0, xh0, xl1, xh1, ...]
// __m256i dot[4];
// for (int ix = 0; ix < nrc_x; ix += 2) {
// auto q8xl = (const block_q8_0 *)cx; cx += bx;
// auto q8xh = (const block_q8_0 *)cx; cx += bx;
// auto acc = _mm256_setzero_ps();
// for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
// auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
// auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
// d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
// auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
// for (int k = 0; k < 4; ++k) {
// d4[2*k+0] = q8xl[k].d;
// d4[2*k+1] = q8xh[k].d;
// auto qxl = _mm256_loadu_si256((const __m256i *)q8xl[k].qs);
// auto qxh = _mm256_loadu_si256((const __m256i *)q8xh[k].qs);
// auto uxl = _mm256_sign_epi8(qxl, qxl);
// auto uxh = _mm256_sign_epi8(qxh, qxh);
// auto qyl = _mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k);
// auto qyh = _mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k);
// auto pll = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyl, qxl)));
// auto plh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyh, qxl)));
// auto phl = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyl, qxh)));
// auto phh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyh, qxh)));
// auto pyl = _mm256_add_epi32(_mm256_unpacklo_epi32(pll, phl), _mm256_unpackhi_epi32(pll, phl));
// auto pyh = _mm256_add_epi32(_mm256_unpacklo_epi32(plh, phh), _mm256_unpackhi_epi32(plh, phh));
// // ll, hl, lh, hh, ll, hl, lh, hh
// dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pyl, pyh), _mm256_unpackhi_epi64(pyl, pyh));
// }
// auto d4x = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)d4)); // all 8 interleaved x scales
// auto d4xyl = _mm256_mul_ps(d4x, d4yl);
// auto d4xyh = _mm256_mul_ps(d4x, d4yh);
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc);
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc);
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc);
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc);
// q8xl += 4; q8xh += 4;
// }
// if (int ib0 = 4*(nblock/4); ib0 < nblock) {
// auto yl = q8.y[0] + ib0;
// auto yh = q8.y[1] + ib0;
// for (int k = 0; k < nblock - ib0; ++k) {
// auto qx = _mm256_loadu_si256((const __m256i *)q8xl[k].qs);
// auto ux = _mm256_sign_epi8(qx, qx);
// auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
// auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
// pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
// ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
// auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
// auto d = GGML_FP16_TO_FP32(q8xl[k].d);
// auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
// auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc);
// }
// }
// auto sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
// float daux[4]; _mm_storeu_ps(daux, sum); // [ll, hl, lh, hh]
// info.store(ix+0, 0, daux[0]); info.store(ix+1, 0, daux[1]);
// info.store(ix+0, 1, daux[2]); info.store(ix+1, 1, daux[3]);
// }
//}
struct Dequantizer4bit {
const __m256i m4 = _mm256_set1_epi8(0xf);
inline __m256i dequant(const uint8_t * qs) const {
@@ -9156,6 +9334,16 @@ void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info,
}
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
//if (std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
// m.funcs[0] = mul_mat_q8_0_q8_x4_1;
// m.funcs[1] = mul_mat_q8_0_q8_x4_2;
// m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
// m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
// m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
// m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
// m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
// m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
//}
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
@@ -16080,12 +16268,13 @@ struct FlashMS {
}
return F16::reduce_max<k_step>(vk);
}
static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
// the mask now arrives as an additive fp16 bias (0 or -inf): convert and add
return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
//auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
//m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
//auto m256 = _mm256_cvtepi16_epi32(m128);
//auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
//return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
}
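The rewritten apply_mask assumes the mask buffer now stores fp16 additive biases, 0.0 for visible positions and -inf for masked ones, instead of the zero/non-zero flags the commented-out version selected on. A small illustration of that convention (hypothetical values, not from the commit; build with -mavx2 -mf16c):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    // fp16 bias vector built by hand: even lanes visible (0x0000 == 0.0f),
    // odd lanes masked (0xFC00 == fp16 -inf, which converts to fp32 -inf)
    uint16_t bias[8];
    for (int i = 0; i < 8; ++i) bias[i] = (i % 2) ? 0xFC00 : 0x0000;
    __m256 val    = _mm256_set1_ps(1.25f);
    __m256 masked = _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)bias)));
    float out[8]; _mm256_storeu_ps(out, masked);
    printf("%g %g\n", out[0], out[1]); // 1.25 -inf: exp(-inf) = 0 in the softmax
}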
#ifdef __AVX512F__
static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
@@ -16622,6 +16811,8 @@ struct FlashQKfp32 {
#ifdef HAVE_FANCY_SIMD
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
#else
if (nq == 1) return std::make_pair(mul_mat_q8_0_q8_x4_1, 1);
if (nq == 2) return std::make_pair(mul_mat_q8_0_q8_x4_2, 2);
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
#endif
#endif