mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Make q8_0_r4 work with tensor row sizes that are not a multiple of 128, on AVX2
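Summary of the change: the kernel previously required each row to contain a multiple of 128 int8 values (`GGML_ASSERT(nb%4 == 0)`, with 32 quants per `QK8_0` block, so 4 blocks = 128 values). The commit removes that assert, factors the per-block inner product into a reusable `dot` lambda over the `qx`/`sx` registers, and adds a tail loop over the `nb%4` leftover blocks. A minimal sketch of the resulting loop structure, with hypothetical `process_block4`/`process_block1` helpers standing in for the real SIMD bodies in the diff below:

#include <cstdio>

// Hypothetical stand-ins for the real AVX2 block kernels in the diff.
static void process_block4(int ib4) { std::printf("4-block group %d\n", ib4); }
static void process_block1(int ib)  { std::printf("single block %d\n", ib);  }

// Loop structure after this commit: keep the fast main loop over groups
// of 4 blocks, then handle the nb%4 leftover blocks one at a time.
static void process_row(int nb) {   // nb = n / QK8_0 = blocks per row
    for (int ib4 = 0; ib4 < nb/4; ++ib4) process_block4(ib4);
    for (int ib = 4*(nb/4); ib < nb; ++ib) process_block1(ib);
}

int main() { process_row(7); }      // 7 blocks = 224 values: 1 group + 3 leftovers

Note one difference in the tail: the activation scales are no longer pre-gathered into `d8[]` from `block_q8_1_x4`; instead `q8.y[iy]` is reinterpreted as plain `block_q8_1` and each block's scale is converted on the fly with `GGML_FP16_TO_FP32(qy[ib].d)`.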
@@ -3152,9 +3152,22 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
     Q8<nrc_y, block_q8_1_x4> q8(info);
     auto m1 = _mm256_set1_epi16(1);
     int nb = n / QK8_0;
-    GGML_ASSERT(nb%4 == 0);
     __m256 acc[nrc_y] = {};
     float d8[4*nrc_y];
+    __m256i qx[4], sx[4];
+    auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
+        auto y128 = _mm_loadu_si128((const __m128i*)qy);
+        auto y = MM256_SET_M128I(y128, y128);
+        auto sumi1 = _mm256_add_epi32(
+            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
+            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
+        );
+        auto sumi2 = _mm256_add_epi32(
+            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
+            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
+        );
+        return _mm256_add_epi32(sumi1, sumi2);
+    };
     for (int ix = 0; ix < nrc_x; ix += 8) {
         const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
         for (int ib4 = 0; ib4 < nb/4; ++ib4) {
@@ -3164,54 +3177,49 @@ static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn
             }
             for (int k = 0; k < 4; ++k) {
                 auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
-                auto q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0);
-                auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1);
-                auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2);
-                auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3);
-                auto s0 = _mm256_sign_epi8(q0, q0);
-                auto s1 = _mm256_sign_epi8(q1, q1);
-                auto s2 = _mm256_sign_epi8(q2, q2);
-                auto s3 = _mm256_sign_epi8(q3, q3);
+                for (int j = 0; j < 4; ++j) {
+                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
+                    sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+                }
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+0);
-                    auto y = MM256_SET_M128I(y128, y128);
-                    auto sumi1 = _mm256_add_epi32(
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
-                    );
-                    auto sumi2 = _mm256_add_epi32(
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
-                    );
-                    auto sumi = _mm256_add_epi32(sumi1, sumi2);
+                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
                     auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
                     acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
                 }
-                q0 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4);
-                q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+5);
-                q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+6);
-                q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+7);
-                s0 = _mm256_sign_epi8(q0, q0);
-                s1 = _mm256_sign_epi8(q1, q1);
-                s2 = _mm256_sign_epi8(q2, q2);
-                s3 = _mm256_sign_epi8(q3, q3);
+                for (int j = 0; j < 4; ++j) {
+                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
+                    sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+                }
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ib4].qs+2*k+1);
-                    auto y = MM256_SET_M128I(y128, y128);
-                    auto sumi1 = _mm256_add_epi32(
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q0))),
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q1)))
-                    );
-                    auto sumi2 = _mm256_add_epi32(
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q2))),
-                        _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q3)))
-                    );
-                    auto sumi = _mm256_add_epi32(sumi1, sumi2);
+                    auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
                     auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
                     acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
                 }
             }
         }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
+            for (int j = 0; j < 4; ++j) {
+                qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
+                sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs);
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+            }
+            for (int j = 0; j < 4; ++j) {
+                qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
+                sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs+16);
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)));
+                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+            }
+        }
         for (int iy = 0; iy < nrc_y; ++iy) {
             info.store(ix, iy, acc[iy]);
             acc[iy] = _mm256_setzero_ps();
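For readers unfamiliar with the idiom inside the `dot` lambda: AVX2's `_mm256_maddubs_epi16` multiplies unsigned bytes by signed bytes, so a signed-by-signed int8 dot product is obtained by using `_mm256_sign_epi8(qx, qx)` (the absolute value of x) as the unsigned operand and transferring x's sign onto y with `_mm256_sign_epi8(y, qx)`; each lane product is unchanged, and `_mm256_madd_epi16` against a vector of ones (`m1`) widens the 16-bit pair sums to 32 bits. A self-contained sketch of just this trick (the scalar-returning reduction is added for illustration; the kernel keeps everything in registers):

#include <immintrin.h>
#include <cstdint>

// Signed int8 dot product of two 32-byte vectors, using the same
// sign-transfer idiom as the dot lambda: |x| * sign_by_x(y) == x * y.
int32_t dot_i8_32(const int8_t * x, const int8_t * y) {
    auto vx  = _mm256_loadu_si256((const __m256i *)x);
    auto vy  = _mm256_loadu_si256((const __m256i *)y);
    auto sx  = _mm256_sign_epi8(vx, vx);                       // |x|, the unsigned operand
    auto sy  = _mm256_sign_epi8(vy, vx);                       // y with x's sign applied
    auto p16 = _mm256_maddubs_epi16(sx, sy);                   // 16 lanes of byte-pair sums
    auto p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1));   // widen to 8 int32 lanes
    // Horizontal reduction of the 8 int32 lanes (illustration only).
    auto s = _mm_add_epi32(_mm256_castsi256_si128(p32), _mm256_extracti128_si256(p32, 1));
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0x4e));          // add upper/lower 64-bit halves
    s = _mm_add_epi32(s, _mm_shuffle_epi32(s, 0xb1));          // add adjacent 32-bit lanes
    return _mm_cvtsi128_si32(s);
}

The `_mm256_shuffle_epi32(y, 0x00/0x55/0xaa/0xff)` calls in the kernel then broadcast successive 4-byte groups of the activation vector so they line up with the 8-row interleaved quants stored in `block_q8_0_r8`.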