iq4_nl_x4: getting amazing

This Zen4 variant gets us to PP-512(LLaMA-3.1-8B) = 263 t/s!
This commit is contained in:
Iwan Kawrakow
2024-11-30 19:32:45 +02:00
parent 422e5768e4
commit 9982d420ef

@@ -2074,190 +2074,110 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif // Zen4 or vanilla AVX2
template <int nrc_y>
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_1_x4> q8(info);
    auto m4 = _mm256_set1_epi8(0xf);
    auto values = load_iq4nl_values_256();
    int nb = n / QK4_NL;
    GGML_ASSERT(nb%4 == 0);
    __m256 acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
            for (int k = 0; k < 4; ++k) {
                // fp16 scales of the four interleaved rows, duplicated into both 128-bit lanes
                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d));
                auto scales = _mm256_set_m128(scales128, scales128);
                // folds the constant offset of the lookup values into the q8_1 block sums below
                auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f));
                auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0);
                auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1);
                // expand the packed 4-bit indices to 8-bit values via the in-register lookup table
                auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4));
                auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4));
                auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4));
                auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4));
                __m256i sumi;
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // broadcast successive 4-byte groups of the q8_1 quants so each dpbusd
                    // accumulates partial dot products for all four interleaved rows at once
                    auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
                    sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00));
                    sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55));
                    sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa));
                    sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff));
                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
                    // d[k+4] holds the q8_1 block sum term d8*sum(q8); this corrects for the
                    // offset lookup values used as the unsigned operand of dpbusd
                    acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]);
                }
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // the two 128-bit halves hold partial sums for the same four rows
            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
            info.store(ix, iy, sum);
            acc[iy] = _mm256_setzero_ps();
        }
    }
}
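// For reference: a scalar sketch of the per-4-row-group sum the kernel above
// accumulates, reassembled from the layout it relies on. It assumes, as the SIMD
// path does, that dequantize_row_iq4_nl_x4() writes the four interleaved rows of
// a block as 4 x 32 contiguous floats; the "_ref" name is illustrative and not
// part of the kernel table.
static void mul_mat_iq4_nl_x4_q8_1_ref(int n, const block_iq4_nl_x4 * iq4,
        const block_q8_1_x4 * y, float * result) {
    int nb = n / QK4_NL;
    GGML_ASSERT(nb%4 == 0);
    float s[4] = {};
    float dequant[128];
    for (int ib4 = 0; ib4 < nb/4; ++ib4) {
        for (int k = 0; k < 4; ++k) {
            dequantize_row_iq4_nl_x4(iq4 + 4*ib4 + k, dequant, 128);
            const float d8 = GGML_FP16_TO_FP32(y[ib4].d[k]);
            const int8_t * qs = y[ib4].qs + 32*k;
            for (int j = 0; j < 32; ++j) {
                s[0] += d8*dequant[j+ 0]*qs[j];   // row 0 of the interleaved group
                s[1] += d8*dequant[j+32]*qs[j];   // row 1
                s[2] += d8*dequant[j+64]*qs[j];   // row 2
                s[3] += d8*dequant[j+96]*qs[j];   // row 3
            }
        }
    }
    for (int j = 0; j < 4; ++j) result[j] = s[j];
}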
template <int nrc_y>
static void mul_mat_iq4_nl_x4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_1_x4> q8(info);
    auto m4 = _mm512_set1_epi8(0xf);
    auto values = load_iq4nl_values_512();
    int nb = n / QK4_NL;
    GGML_ASSERT(nb%4 == 0);
    __m512 acc[2*nrc_y] = {};
    __m512i qx[4];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        // process two interleaved 4-row groups at once: the low 256 bits hold
        // rows ix..ix+3, the high 256 bits rows ix+4..ix+7
        const block_iq4_nl_x4 * iq4l = (const block_iq4_nl_x4 *)((const char *)vx + (ix+0)*bx);
        const block_iq4_nl_x4 * iq4h = (const block_iq4_nl_x4 *)((const char *)vx + (ix+4)*bx);
        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
            for (int k = 0; k < 4; ++k) {
                auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d));
                auto scales1 = _mm256_set_m128(scales128, scales128);
                scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d));
                auto scales2 = _mm256_set_m128(scales128, scales128);
                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
                auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f));
                auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)),
                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1);
                auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)),
                                                _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1);
                qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
                qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
                qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
                qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // the same 32 q8_1 quants are needed in both 256-bit halves
                    auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k);
                    auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
                    auto sumi = _mm512_setzero_si512();
                    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                    auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]));
                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
                    // the bias correction goes into a separate accumulator; the two are summed at the end
                    acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]);
                }
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]);
            acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
            // 128-bit quarters 0+1 reduce to rows ix..ix+3, quarters 2+3 to rows ix+4..ix+7
            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
            info.store(ix+0, iy, sum1);
            info.store(ix+4, iy, sum2);
        }
    }
}
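// A minimal standalone sketch (not part of the kernel table) of the in-register
// 4-bit lookup both variants above rely on: _mm256_shuffle_epi8 treats each low
// nibble of the index vector as an index into the 16-byte table replicated in
// each 128-bit lane, so 64 packed 4-bit codes expand to 8-bit values in a handful
// of instructions, with no memory gather. The table contents here are made up for
// illustration; the real kernels load the iq4_nl value table.
static inline void expand_nibbles_demo(const uint8_t * packed32, uint8_t * out64) {
    static const uint8_t table[16] = {  0,  10,  20,  30,  40,  50,  60,  70,
                                       80,  90, 100, 110, 120, 130, 140, 150};
    auto tbl128 = _mm_loadu_si128((const __m128i *)table);
    auto values = _mm256_set_m128i(tbl128, tbl128);   // same 16 entries in both lanes
    auto m4     = _mm256_set1_epi8(0xf);
    auto bits   = _mm256_loadu_si256((const __m256i *)packed32);
    // low nibbles give the first 32 values; the shift needs a second masking pass
    // because _mm256_srli_epi16 lets bits cross byte boundaries within each word
    auto lo = _mm256_shuffle_epi8(values, _mm256_and_si256(bits, m4));
    auto hi = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits, 4), m4));
    _mm256_storeu_si256((__m256i *)(out64 +  0), lo);
    _mm256_storeu_si256((__m256i *)(out64 + 32), hi);
}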
template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
    if (j == 0) {