mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
WIP
This commit is contained in:
@@ -8492,6 +8492,184 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
|
||||
}
|
||||
}
|
||||
|
||||
void mul_mat_q8_0_q8_x4_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
Q8<1, block_q8_0> q8(info);
|
||||
int nblock = n/QK8_0;
|
||||
auto cx = (const char *)vx;
|
||||
auto y4 = (const block_q8_0_x4 *)q8.y[0];
|
||||
ggml_half d4[4];
|
||||
__m256 dot[4];
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
auto q8x = (const block_q8_0 *)cx;
|
||||
auto acc = _mm256_setzero_ps();
|
||||
for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
|
||||
auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4[ib4].d));
|
||||
auto d4y = _mm256_set_m128(d4y_128, d4y_128);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
d4[k] = q8x[k].d;
|
||||
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
|
||||
auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4[ib4].qs + k), qx));
|
||||
dot[k] = _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_set1_epi16(1), p));
|
||||
}
|
||||
auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
|
||||
auto d4x = _mm256_set_m128(d4x_128, d4x_128);
|
||||
auto d4xy = _mm256_mul_ps(d4x, d4y);
|
||||
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x00), dot[0], acc);
|
||||
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x55), dot[1], acc);
|
||||
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xaa), dot[2], acc);
|
||||
acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xff), dot[3], acc);
|
||||
q8x += 4;
|
||||
}
|
||||
if (int ib0 = 4*(nblock/4); ib0 < nblock) {
|
||||
auto y = q8.y[0] + ib0;
|
||||
for (int k = 0; k < nblock - ib0; ++k) {
|
||||
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
|
||||
auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y[k].qs), qx));
|
||||
p = _mm256_madd_epi16(_mm256_set1_epi16(1), p);
|
||||
auto dxy = GGML_FP16_TO_FP32(q8x[k].d)*GGML_FP16_TO_FP32(y[k].d);
|
||||
acc = _mm256_fmadd_ps(_mm256_set1_ps(dxy), _mm256_cvtepi32_ps(p), acc);
|
||||
}
|
||||
}
|
||||
info.store(ix, 0, hsum_float_8(acc));
|
||||
cx += bx;
|
||||
}
|
||||
}
|
||||
|
||||
void mul_mat_q8_0_q8_x4_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
Q8<2, block_q8_0> q8(info);
|
||||
int nblock = n/QK8_0;
|
||||
auto cx = (const char *)vx;
|
||||
auto y4l = (const block_q8_0_x4 *)q8.y[0];
|
||||
auto y4h = (const block_q8_0_x4 *)q8.y[1];
|
||||
ggml_half d4[4];
|
||||
__m256i dot[4];
|
||||
__m256 acc[4] = {};
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
auto q8x = (const block_q8_0 *)cx;
|
||||
for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
|
||||
auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
|
||||
auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
|
||||
d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
|
||||
auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
d4[k] = q8x[k].d;
|
||||
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
|
||||
auto ux = _mm256_sign_epi8(qx, qx);
|
||||
auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k), qx));
|
||||
auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k), qx));
|
||||
pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
|
||||
ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
|
||||
dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
|
||||
}
|
||||
auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
|
||||
auto d4x = _mm256_set_m128(d4x_128, d4x_128);
|
||||
auto d4xyl = _mm256_mul_ps(d4x, d4yl);
|
||||
auto d4xyh = _mm256_mul_ps(d4x, d4yh);
|
||||
acc[0] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc[0]);
|
||||
acc[1] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc[1]);
|
||||
acc[2] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc[2]);
|
||||
acc[3] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc[3]);
|
||||
q8x += 4;
|
||||
}
|
||||
if (int ib0 = 4*(nblock/4); ib0 < nblock) {
|
||||
auto yl = q8.y[0] + ib0;
|
||||
auto yh = q8.y[1] + ib0;
|
||||
for (int k = 0; k < nblock - ib0; ++k) {
|
||||
auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
|
||||
auto ux = _mm256_sign_epi8(qx, qx);
|
||||
auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
|
||||
auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
|
||||
pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
|
||||
ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
|
||||
auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
|
||||
auto d = GGML_FP16_TO_FP32(q8x[k].d);
|
||||
auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
|
||||
auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
|
||||
acc[k] = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc[k]);
|
||||
}
|
||||
}
|
||||
acc[0] = _mm256_add_ps(acc[0], acc[1]);
|
||||
acc[2] = _mm256_add_ps(acc[2], acc[3]);
|
||||
acc[0] = _mm256_add_ps(acc[0], acc[2]);
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[0]), _mm256_extractf128_ps(acc[0], 1));
|
||||
float daux[4]; _mm_storeu_ps(daux, sum);
|
||||
info.store(ix, 0, daux[0] + daux[1]);
|
||||
info.store(ix, 1, daux[2] + daux[3]);
|
||||
acc[0] = acc[1] = acc[2] = acc[3] = _mm256_setzero_ps();
|
||||
cx += bx;
|
||||
}
|
||||
}
|
||||
|
||||
//void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
// Q8<2, block_q8_0> q8(info);
|
||||
// int nblock = n/QK8_0;
|
||||
// auto cx = (const char *)vx;
|
||||
// auto y4l = (const block_q8_0_x4 *)q8.y[0];
|
||||
// auto y4h = (const block_q8_0_x4 *)q8.y[1];
|
||||
// ggml_half d4[4];
|
||||
// __m256i dot[4];
|
||||
// for (int ix = 0; ix < nrc_x; ix += 2) {
|
||||
// auto q8xl = (const block_q8_0 *)cx; cx += bx;
|
||||
// auto q8xh = (const block_q8_0 *)cx; cx += bx;
|
||||
// auto accl = _mm256_setzero_ps();
|
||||
// auto acch = _mm256_setzero_ps();
|
||||
// for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
|
||||
// auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
|
||||
// auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
|
||||
// d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
|
||||
// auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
|
||||
// for (int k = 0; k < 4; ++k) {
|
||||
// d4[2*k+0] = q8xl[k].d;
|
||||
// d4[2*k+1] = q8xh[k].d;
|
||||
// auto qxl = _mm256_loadu_si256((const __m256i *)q8xl[k].qs);
|
||||
// auto qxh = _mm256_loadu_si256((const __m256i *)q8xh[k].qs);
|
||||
// auto uxl = _mm256_sign_epi8(qxl, qxl);
|
||||
// auto uxh = _mm256_sign_epi8(qxh, qxh);
|
||||
// auto qyl = _mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k);
|
||||
// auto qyh = _mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k);
|
||||
// auto pll = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyl, qxl)));
|
||||
// auto plh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyh, qxl)));
|
||||
// auto phl = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyl, qxh)));
|
||||
// auto phh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyh, qxh)));
|
||||
// auto pyl = _mm256_add_epi32(_mm256_unpacklo_epi32(pll, phl), _mm256_unpackhi_epi32(pll, phl));
|
||||
// auto pyh = _mm256_add_epi32(_mm256_unpacklo_epi32(plh, phh), _mm256_unpackhi_epi32(plh, phh));
|
||||
// // ll, hl, lh, hh, ll, hl, lh, hh
|
||||
// dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pyl, pyh), _mm256_unpackhi_epi64(pyl, pyh));
|
||||
// }
|
||||
// auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
|
||||
// auto d4x = _mm256_set_m128(d4x_128, d4x_128);
|
||||
// auto d4xyl = _mm256_mul_ps(d4x, d4yl);
|
||||
// auto d4xyh = _mm256_mul_ps(d4x, d4yh);
|
||||
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc);
|
||||
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc);
|
||||
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc);
|
||||
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc);
|
||||
// q8x += 4;
|
||||
// }
|
||||
// if (int ib0 = 4*(nblock/4); ib0 < nblock) {
|
||||
// auto yl = q8.y[0] + ib0;
|
||||
// auto yh = q8.y[1] + ib0;
|
||||
// for (int k = 0; k < nblock - ib0; ++k) {
|
||||
// auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
|
||||
// auto ux = _mm256_sign_epi8(qx, qx);
|
||||
// auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
|
||||
// auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
|
||||
// pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
|
||||
// ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
|
||||
// auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
|
||||
// auto d = GGML_FP16_TO_FP32(q8x[k].d);
|
||||
// auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
|
||||
// auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
|
||||
// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc);
|
||||
// }
|
||||
// }
|
||||
// auto sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
|
||||
// float daux[4]; _mm_storeu_ps(daux, sum);
|
||||
// info.store(ix, 0, daux[0] + daux[1]);
|
||||
// info.store(ix, 1, daux[2] + daux[3]);
|
||||
// }
|
||||
//}
|
||||
|
||||
struct Dequantizer4bit {
|
||||
const __m256i m4 = _mm256_set1_epi8(0xf);
|
||||
inline __m256i dequant(const uint8_t * qs) const {
|
||||
@@ -9156,6 +9334,16 @@ void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
}
|
||||
|
||||
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
//if (std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
|
||||
// m.funcs[0] = mul_mat_q8_0_q8_x4_1;
|
||||
// m.funcs[1] = mul_mat_q8_0_q8_x4_2;
|
||||
// m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
|
||||
// m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
|
||||
// m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
|
||||
// m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
|
||||
// m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
|
||||
// m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
|
||||
//}
|
||||
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
|
||||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
|
||||
m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
|
||||
@@ -16080,12 +16268,13 @@ struct FlashMS {
|
||||
}
|
||||
return F16::reduce_max<k_step>(vk);
|
||||
}
|
||||
static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
|
||||
auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
|
||||
m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
|
||||
auto m256 = _mm256_cvtepi16_epi32(m128);
|
||||
auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
|
||||
return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
|
||||
static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
|
||||
return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
|
||||
//auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
|
||||
//m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
|
||||
//auto m256 = _mm256_cvtepi16_epi32(m128);
|
||||
//auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
|
||||
//return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
|
||||
}
|
||||
#ifdef __AVX512F__
|
||||
static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
|
||||
@@ -16622,6 +16811,8 @@ struct FlashQKfp32 {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
MAKE_FUNCS(mul_mat_qX_1_q8_2_T<Q8_0_1_Unpacker, nq);
|
||||
#else
|
||||
if (nq == 1) return std::make_pair(mul_mat_q8_0_q8_x4_1, 1);
|
||||
if (nq == 2) return std::make_pair(mul_mat_q8_0_q8_x4_2, 2);
|
||||
MAKE_FUNCS(mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user