mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 18:10:02 +00:00
iq4_nl_x4: getting amazing performance
This Zen4 variant gets us to PP-512(LLaMA-3.1-8B) = 263 t/s!
This commit is contained in:
@@ -2074,190 +2074,110 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
//printf("%s: n = %d, nrc_x = %d, bx = %zu\n", __func__, n, nrc_x, bx);
|
||||
Q8<nrc_y, block_q8_1_x4> q8(info);
|
||||
auto m4 = _mm256_set1_epi8(0xf);
|
||||
auto values = load_iq4nl_values_256();
|
||||
int nb = n / QK4_NL;
|
||||
//float dequant[128], stored[4], s[4*nrc_y];
|
||||
//float dequant[128], s[4*nrc_y];
|
||||
GGML_ASSERT(nb%4 == 0);
|
||||
__m256 acc[nrc_y] = {};
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
|
||||
//for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps();
|
||||
//for (int j = 0; j < 4*nrc_y; ++j) s[j] = 0;
|
||||
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
//dequantize_row_iq4_nl_x4(iq4+4*ib4+k, dequant, 128);
|
||||
auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
|
||||
auto scales = _mm256_set_m128(scales128, scales128);
|
||||
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
|
||||
//auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
|
||||
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
|
||||
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
|
||||
auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
|
||||
auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
|
||||
auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
|
||||
auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
|
||||
__m256i sumi;
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
//float d8 = GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]);
|
||||
//const int8_t * qs = q8.y[iy][ib4].qs+32*k;
|
||||
////s[0] = s[1] = s[2] = s[3] = 0;
|
||||
//for (int j = 0; j < 32; ++j) {
|
||||
// s[4*iy+0] += d8*dequant[j+ 0]*qs[j];
|
||||
// s[4*iy+1] += d8*dequant[j+32]*qs[j];
|
||||
// s[4*iy+2] += d8*dequant[j+64]*qs[j];
|
||||
// s[4*iy+3] += d8*dequant[j+96]*qs[j];
|
||||
//}
|
||||
////s[0] *= d8; s[1] *= d8; s[2] *= d8; s[3] *= d8;
|
||||
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
|
||||
sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
|
||||
sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
|
||||
//auto d4d8 = _mm256_mul_ps(scales, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k])));
|
||||
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
|
||||
//auto check = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), _mm256_setzero_ps());
|
||||
//check = _mm256_fmadd_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])), check);
|
||||
//auto check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
|
||||
//_mm_storeu_ps(stored, check128);
|
||||
//bool is_good = true;
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
// if (std::abs(stored[j] - s[j]) > 1e-1*(std::abs(s[j])+std::abs(stored[j]))) is_good = false;
|
||||
//}
|
||||
//if (!is_good) {
|
||||
// static int ncount = 0;
|
||||
// printf("Oops\n");
|
||||
// for (int j = 0; j < 4; ++j) printf("%g vs %g, diff = %g, thresh = %g\n", s[j], stored[j], std::abs(stored[j] - s[j]), 1e-2f*std::abs(s[j]));
|
||||
// _mm_storeu_ps(stored, scales128);
|
||||
// printf("iq4_nl scales: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
|
||||
// printf("d8 = %g, %g\n", d8, GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4]));
|
||||
// _mm_storeu_ps(stored, _mm256_castps256_ps128(d4d8));
|
||||
// printf("iq4_nl scales x q8: %g, %g, %g, %g", stored[0], stored[1], stored[2], stored[3]);
|
||||
// _mm_storeu_ps(stored, _mm256_extractf128_ps(d4d8, 1));
|
||||
// printf(" %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
|
||||
// check = _mm256_mul_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])));
|
||||
// check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
|
||||
// _mm_storeu_ps(stored, check128);
|
||||
// printf("minus: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
|
||||
// check = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), _mm256_setzero_ps());
|
||||
// check128 = _mm_add_ps(_mm256_castps256_ps128(check), _mm256_extractf128_ps(check, 1));
|
||||
// _mm_storeu_ps(stored, check128);
|
||||
// printf("plus: %g, %g, %g, %g\n", stored[0], stored[1], stored[2], stored[3]);
|
||||
// auto sumi128 = _mm_add_epi32(_mm256_castsi256_si128(sumi), _mm256_extractf128_si256(sumi, 1));
|
||||
// _mm_storeu_si128((__m128i *)stored, sumi128);
|
||||
// const int32_t * aux32 = (const int32_t *)stored;
|
||||
// printf("sumi: %d, %d, %d, %d\n", aux32[0], aux32[1], aux32[2], aux32[3]);
|
||||
// //uint8_t aux[32];
|
||||
// //_mm256_storeu_si256((__m256i *)aux, q1);
|
||||
// //printf("=== q1\n");
|
||||
// //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
|
||||
// //_mm256_storeu_si256((__m256i *)aux, q2);
|
||||
// //printf("=== q2\n");
|
||||
// //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
|
||||
// //_mm256_storeu_si256((__m256i *)aux, q3);
|
||||
// //printf("=== q3\n");
|
||||
// //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
|
||||
// //_mm256_storeu_si256((__m256i *)aux, q4);
|
||||
// //printf("=== q4\n");
|
||||
// //for (int j = 0; j < 32; ++j) printf("%d %u\n", j, aux[j]);
|
||||
// if (++ncount > 10) GGML_ABORT("fatal error");
|
||||
//}
|
||||
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
//acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4])), acc[iy]);
|
||||
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
//_mm_storeu_ps(stored, sum);
|
||||
//bool is_good = true;
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
// if (std::abs(stored[j] - s[j]) > 1e-4f) { printf("Oops: %g vs %g\n", stored[j], s[j]); is_good = false; }
|
||||
//}
|
||||
//if (!is_good) GGML_ABORT("fatal error");
|
||||
info.store(ix, iy, sum);
|
||||
//for (int k = 0; k < 4; ++k) info.store(ix+k, iy, s[4*iy+k]);
|
||||
acc[iy] = _mm256_setzero_ps();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//template <int nrc_y>
|
||||
//static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
// GGML_ASSERT(nrc_x%8 == 0);
|
||||
// //printf("%s: n = %d, nrc_x = %d, bx = %zu\n", __func__, n, nrc_x, bx);
|
||||
// GGML_ASSERT(nrc_x%4 == 0);
|
||||
// Q8<nrc_y, block_q8_1_x4> q8(info);
|
||||
// auto m4 = _mm256_set1_epi8(0xf);
|
||||
// auto values = load_iq4nl_values_256();
|
||||
// int nb = n / QK4_NL;
|
||||
// GGML_ASSERT(nb%4 == 0);
|
||||
// __m256 acc[2*nrc_y] = {};
|
||||
// __m256i qx[8];
|
||||
// for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
// const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
|
||||
// const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
|
||||
// //__m256 acc[nrc_y] = {};
|
||||
// __m256 acc[2*nrc_y] = {};
|
||||
// for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
// const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
|
||||
// for (int ib4 = 0; ib4 < nb/4; ++ib4) {
|
||||
// for (int k = 0; k < 4; ++k) {
|
||||
// auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
|
||||
// auto scales1 = _mm256_set_m128(scales128, scales128);
|
||||
// auto scales1_m = _mm256_mul_ps(scales1, _mm256_set1_ps(-64.f));
|
||||
// auto bits1 = _mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0);
|
||||
// auto bits2 = _mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1);
|
||||
// qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
|
||||
// qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
|
||||
// qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
|
||||
// qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
|
||||
// scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
|
||||
// auto scales2 = _mm256_set_m128(scales128, scales128);
|
||||
// auto scales2_m = _mm256_mul_ps(scales2, _mm256_set1_ps(-64.f));
|
||||
// bits1 = _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0);
|
||||
// bits2 = _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1);
|
||||
// qx[4] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
|
||||
// qx[5] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
|
||||
// qx[6] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
|
||||
// qx[7] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
|
||||
// __m256i sumi1, sumi2;
|
||||
// auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
|
||||
// auto scales = _mm256_set_m128(scales128, scales128);
|
||||
// auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
|
||||
// auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
|
||||
// auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
|
||||
// auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
|
||||
// auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
|
||||
// auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
|
||||
// auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
|
||||
// __m256i sumi;
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
|
||||
// auto sy = _mm256_shuffle_epi32(y, 0x00);
|
||||
// sumi1 = _mm256_dpbusd_epi32(qx[0], sy, _mm256_setzero_si256());
|
||||
// sumi2 = _mm256_dpbusd_epi32(qx[4], sy, _mm256_setzero_si256());
|
||||
// sy = _mm256_shuffle_epi32(y, 0x55);
|
||||
// sumi1 = _mm256_dpbusd_epi32(qx[1], sy, sumi1);
|
||||
// sumi2 = _mm256_dpbusd_epi32(qx[5], sy, sumi2);
|
||||
// sy = _mm256_shuffle_epi32(y, 0xaa);
|
||||
// sumi1 = _mm256_dpbusd_epi32(qx[2], sy, sumi1);
|
||||
// sumi2 = _mm256_dpbusd_epi32(qx[6], sy, sumi2);
|
||||
// sy = _mm256_shuffle_epi32(y, 0xff);
|
||||
// sumi1 = _mm256_dpbusd_epi32(qx[3], sy, sumi1);
|
||||
// sumi2 = _mm256_dpbusd_epi32(qx[7], sy, sumi2);
|
||||
// auto dy = _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k]));
|
||||
// acc[2*iy+0] = _mm256_fmadd_ps(_mm256_mul_ps(scales1, dy), _mm256_cvtepi32_ps(sumi1), acc[2*iy+0]);
|
||||
// acc[2*iy+1] = _mm256_fmadd_ps(_mm256_mul_ps(scales2, dy), _mm256_cvtepi32_ps(sumi2), acc[2*iy+1]);
|
||||
// dy = _mm256_cvtph_ps(_mm_set1_epi16(q8.y[iy][ib4].d[k+4]));
|
||||
// acc[2*iy+0] = _mm256_fmadd_ps(scales1_m, dy, acc[2*iy+0]);
|
||||
// acc[2*iy+1] = _mm256_fmadd_ps(scales2_m, dy, acc[2*iy+1]);
|
||||
// sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
|
||||
// sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
|
||||
// sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
|
||||
// sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
|
||||
// auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
|
||||
// //acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
|
||||
// //acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
|
||||
// acc[2*iy+0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[2*iy+0]);
|
||||
// acc[2*iy+1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[2*iy+0]), _mm256_extractf128_ps(acc[2*iy+0], 1));
|
||||
// info.store(ix+0, iy, sum);
|
||||
// sum = _mm_add_ps(_mm256_castps256_ps128(acc[2*iy+1]), _mm256_extractf128_ps(acc[2*iy+1], 1));
|
||||
// info.store(ix+4, iy, sum);
|
||||
// auto sum256 = _mm256_add_ps(acc[2*iy+0], acc[2*iy+1]);
|
||||
// acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_ps();
|
||||
// auto sum = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1));
|
||||
// info.store(ix, iy, sum);
|
||||
// //auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
|
||||
// //info.store(ix, iy, sum);
|
||||
// //acc[iy] = _mm256_setzero_ps();
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
// AVX-512 variant: processes 8 interleaved rows per outer iteration by packing
// two iq4_nl_x4 groups of 4 rows into the low/high 256-bit halves of each
// 512-bit register. See the AVX2 variant for parameter semantics.
template <int nrc_y>
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_1_x4> q8(info);
    auto m4 = _mm512_set1_epi8(0xf);        // low-nibble mask
    auto values = load_iq4nl_values_512();  // iq4_nl codebook replicated per 128-bit lane
    int nb = n / QK4_NL;                    // number of 32-element blocks per row
    GGML_ASSERT(nb%4 == 0);                 // block_q8_1_x4 packs 4 blocks; need whole packs
    // Two accumulators per y row: [2*iy+0] collects the dot products,
    // [2*iy+1] collects the offset-correction terms; added together at the end.
    __m512 acc[2*nrc_y] = {};
    __m512i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        // Low group (rows ix..ix+3) and high group (rows ix+4..ix+7).
        const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
        const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
            for (int k = 0; k < 4; ++k) {
                // fp16 row scales of both groups -> one 512-bit vector
                // (low group in the low 256 bits, high group in the high 256 bits).
                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
                auto scales1 = _mm256_set_m128(scales128, scales128);
                scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
                auto scales2 = _mm256_set_m128(scales128, scales128);
                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
                // NOTE(review): the -64 factor presumably compensates the unsigned
                // shift of the iq4_nl lookup values, applied via the q8_1 sums in
                // d[k+4] -- confirm against the quantization code.
                auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
                // Quant bytes of both groups packed the same way as the scales.
                auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
                        _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
                auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
                        _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
                // Expand low and high nibbles through the codebook.
                qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
                qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
                qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
                qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // The same 32 activation bytes are used against both row groups,
                    // so broadcast them into both 256-bit halves.
                    auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
                    auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
                    auto sumi = _mm512_setzero_si512();
                    // Each shuffle broadcasts one 4-byte group of y so dpbusd pairs
                    // it with the matching quad of the interleaved x4 quants.
                    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                    // d[k] is the fp16 scale of the k-th q8_1 block in this pack.
                    auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
                    // Offset correction using d[k+4] (per-block q8 sums; see note above).
                    acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
                }
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
            acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
            // Quarters 0+1 hold the partial sums for rows ix..ix+3,
            // quarters 2+3 those for rows ix+4..ix+7.
            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
            info.store(ix+0, iy, sum1);
            info.store(ix+4, iy, sum2);
        }
    }
}
|
||||
|
||||
template <typename Bits>
|
||||
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
|
||||
if (j == 0) {
|
||||
|
||||
Reference in New Issue
Block a user