Make q8_0_r4 work with tensor row sizes that are not a multiple of 128

.., on AVX2
This commit is contained in:
Iwan Kawrakow
2025-01-28 19:59:29 +02:00
parent d3545680b9
commit 4d7dc72d41

View File

@@ -3152,9 +3152,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
Q8<nrc_y, block_q8_1_x4> q8(info);
auto m1 = _mm256_set1_epi16(1);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
__m256i qx[4], sx[4];
auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
auto y128 = _mm_loadu_si128((const __m128i*)qy);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
);
return _mm256_add_epi32(sumi1, sumi2);
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
@@ -3164,54 +3177,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
}
for (int k = 0; k < 4; ++k) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
auto s0 = _mm256_sign_epi8(q0, q0);
auto s1 = _mm256_sign_epi8(q1, q1);
auto s2 = _mm256_sign_epi8(q2, q2);
auto s3 = _mm256_sign_epi8(q3, q3);
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
);
auto sumi = _mm256_add_epi32(sumi1, sumi2);
auto sumi = dot(q8.y[iy][ib4].qs+32*k);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
s0 = _mm256_sign_epi8(q0, q0);
s1 = _mm256_sign_epi8(q1, q1);
s2 = _mm256_sign_epi8(q2, q2);
s3 = _mm256_sign_epi8(q3, q3);
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
auto y = MM256_SET_M128I(y128, y128);
auto sumi1 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
);
auto sumi2 = _mm256_add_epi32(
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
);
auto sumi = _mm256_add_epi32(sumi1, sumi2);
auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();