iq6_k: slightly better Zen4 iqk_mul_mat

We now arrive at PP-512 = 147 t/s for LLaMA-3.1-8B. TG-128 is 9.5 t/s. This is better than the last commit, but still somewhat slow compared to Q6_K.

My last commit message was wrong: iq3_k also needs a fix for overflow.
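
The overflow in question is presumably the saturating int16 arithmetic of the maddubs-style dot products used below: unsigned 8-bit quants are multiplied by signed 8-bit q8 values and adjacent pairs are added into a signed 16-bit result. A minimal scalar sketch of that failure mode (an illustration only, not code from this commit; it just models the PMADDUBSW pairwise saturation):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Scalar model of one int16 lane of _mm256_maddubs_epi16 / _mm512_maddubs_epi16:
    // two u8*i8 products are summed and saturated to the int16 range.
    static int16_t maddubs_lane(uint8_t a0, uint8_t a1, int8_t b0, int8_t b1) {
        int32_t sum = int32_t(a0) * b0 + int32_t(a1) * b1;        // exact pairwise sum
        return int16_t(std::clamp<int32_t>(sum, -32768, 32767));  // saturating pack
    }

    int main() {
        // Two large unsigned quants times large positive q8 values exceed INT16_MAX,
        // so the saturated lane no longer equals its true contribution to the dot product.
        int32_t exact = int32_t(200) * 127 + int32_t(200) * 127;
        printf("exact = %d, saturated = %d\n", exact, maddubs_lane(200, 200, 127, 127));
        return 0;
    }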
Iwan Kawrakow, 2024-08-08 14:01:08 +03:00 (committed by Kawrakow)
parent 849476acc7
commit 595d2ae32d

@@ -967,30 +967,19 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
         prepare(x[i].qs, x[i].qh);
         auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
         auto scales16 = _mm256_cvtepi8_epi16(scales8);
-        scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
+        auto bs = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
         for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-            auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
+            auto prod = _mm256_madd_epi16(bs, q8.load_bsums(iy, i));
             accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
         }
-        scales16 = MM256_SET_M128I(scales8, scales8);
-        scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
-        scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
-    }
-    template <typename Q8>
-    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512 * scales) {
-        d = GGML_FP16_TO_FP32(x[i].d);
-        prepare(x[i].qs, x[i].qh);
-        auto scales8 = _mm_loadu_si128((const __m128i*)x[i].scales);
-        auto scales16 = _mm256_cvtepi8_epi16(scales8);
-        scales16 = _mm256_mullo_epi16(scales16, _mm256_mask_add_epi16(min, x[i].extra, min, shift));
-        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-            auto prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
-            accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
-        }
-        auto vd = Q8::nrc_y == 1 ? _mm512_set1_ps(d*q8.scale(0, i)) : _mm512_set1_ps(d);
-        for (int k = 0; k < 4; ++k) {
-            scales[k] = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, shuffles[k])))));
-        }
+        auto aux_1 = MM256_SET_M128I(_mm256_castsi256_si128(scales16), _mm256_castsi256_si128(scales16));
+        auto aux_2 = MM256_SET_M128I(_mm256_extracti128_si256(scales16, 1), _mm256_extracti128_si256(scales16, 1));
+        auto scales16_1 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_1), aux_1, 1);
+        auto scales16_2 = _mm512_inserti32x8(_mm512_castsi256_si512(aux_2), aux_2, 1);
+        scales[0] = _mm512_shuffle_epi8(scales16_1, shuffles[0]);
+        scales[1] = _mm512_shuffle_epi8(scales16_1, shuffles[1]);
+        scales[2] = _mm512_shuffle_epi8(scales16_2, shuffles[0]);
+        scales[3] = _mm512_shuffle_epi8(scales16_2, shuffles[1]);
     }
     inline __m512i make_one(__m512i l, __m512i h) const {
         auto p = _mm512_shuffle_epi8(values[0], l);
@@ -1040,9 +1029,11 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
     __m512i masks[3] = { _mm512_set1_epi8(0x01), _mm512_set1_epi8(0x02), _mm512_set1_epi8(0x03) };
     const __m256i min = _mm256_set1_epi16(-128);
     const __m256i shift = _mm256_set1_epi16(1);
-    const __m128i shuffles[4] = {
-        _mm_set_epi64x(0x0303030302020202, 0x0101010100000000), _mm_set_epi64x(0x0707070706060606, 0x0505050504040404),
-        _mm_set_epi64x(0x0b0b0b0b0a0a0a0a, 0x0909090908080808), _mm_set_epi64x(0x0f0f0f0f0e0e0e0e, 0x0d0d0d0d0c0c0c0c),
+    const __m512i shuffles[2] = {
+        _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+                    _mm_set1_epi16(0x0100), 0), _mm_set1_epi16(0x0302), 1), _mm_set1_epi16(0x0504), 2), _mm_set1_epi16(0x0706), 3),
+        _mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_inserti32x4(_mm512_setzero_si512(),
+                    _mm_set1_epi16(0x0908), 0), _mm_set1_epi16(0x0b0a), 1), _mm_set1_epi16(0x0d0c), 2), _mm_set1_epi16(0x0f0e), 3)
     };
     const __m256i shuffle1 = _mm256_set_epi64x(0x0707070703030303, 0x0606060602020202, 0x0505050501010101, 0x0404040400000000);
     const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0b0b0b0b, 0x0e0e0e0e0a0a0a0a, 0x0d0d0d0d09090909, 0x0c0c0c0c08080808);
@@ -1135,7 +1126,7 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
     __m256 accm[nrc_y];
     __m512 accd[nrc_y];
-    __m512 scales[4];
+    __m512i scales[4];
     for (int ix = 0; ix < nrc_x; ++ix) {
@@ -1149,22 +1140,13 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
             deq.new_block(i, q8, accm, scales);
             for (int iy = 0; iy < nrc_y; ++iy) {
-                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
-                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
-                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
-                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
-                if constexpr (nrc_y == 1) {
-                    accd[iy] = _mm512_fmadd_ps(scales[0], _mm512_cvtepi32_ps(p1), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(scales[1], _mm512_cvtepi32_ps(p2), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(scales[2], _mm512_cvtepi32_ps(p3), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(scales[3], _mm512_cvtepi32_ps(p4), accd[iy]);
-                } else {
-                    auto d8 = _mm512_set1_ps(q8.scale(iy, i));
-                    accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[0]), _mm512_cvtepi32_ps(p1), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[1]), _mm512_cvtepi32_ps(p2), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[2]), _mm512_cvtepi32_ps(p3), accd[iy]);
-                    accd[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d8, scales[3]), _mm512_cvtepi32_ps(p4), accd[iy]);
-                }
+                const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
+                const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
+                const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
+                const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
+                auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
+                                p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
+                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
             }
         }
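
For reference, a standalone sketch of the accumulation pattern in the last hunk: the 16-bit maddubs products are folded against broadcast 16-bit block scales with dpwssd, and a single float fmadd per row applies the super-block scale and the q8 scale, instead of four per-scale float fmadds. This is an illustration with dummy data and placeholder names (quants, q8, scales16), assuming AVX512F/BW/VNNI as on Zen4; it is not the repository's code.

    // Compile with, e.g.: g++ -O2 -mavx512f -mavx512bw -mavx512vnni demo.cpp
    #include <immintrin.h>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Placeholder data standing in for one 64-byte slice of unpacked quants and q8 activations.
        alignas(64) uint8_t quants[64];
        alignas(64) int8_t  q8[64];
        for (int j = 0; j < 64; ++j) { quants[j] = uint8_t(3 + (j & 7)); q8[j] = int8_t(j % 5 - 2); }

        auto vq = _mm512_load_si512((const void *)quants);
        auto v8 = _mm512_load_si512((const void *)q8);

        // u8 x i8 -> pairwise-summed int16 products (the step that can saturate, see the note above).
        auto p = _mm512_maddubs_epi16(vq, v8);

        // Fold the 16-bit products against broadcast 16-bit block scales into int32 sums.
        auto scales16 = _mm512_set1_epi16(5);   // placeholder block scale
        auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), p, scales16);

        // One float fmadd applies the super-block scale d and the q8 row scale.
        float d = 0.125f, d8 = 0.25f;
        auto acc = _mm512_fmadd_ps(_mm512_set1_ps(d * d8), _mm512_cvtepi32_ps(sumi), _mm512_setzero_ps());

        float out[16]; _mm512_storeu_ps(out, acc);
        float total = 0; for (int k = 0; k < 16; ++k) total += out[k];
        printf("accumulated dot product = %g\n", total);
        return 0;
    }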