iq3_k: AVX2 iqk_mul_mat

We get PP-512 = 196 t/s for LLaMA-3.1-8B on the Ryzen-5975WX.
This commit is contained in:
Kawrakow
2024-07-30 19:01:35 +03:00
committed by Kawrakow
parent a9fa3b1563
commit 9c1eea6048

View File

@@ -1233,6 +1233,51 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
const __m128i maskl = _mm_set1_epi8(0xf);
};
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_h, x[i].scales_l), q8, accm, scales);
hbits = _mm256_loadu_si256((const __m256i *)x[i].qh);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
auto h256 = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4);
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(h256, 2), hmask));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(h256, 1), hmask));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(h256, hmask));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(h256, 1), hmask));
bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
}
static inline __m256i load_values() {
static const uint8_t kvalues_iq3nl[16] = {1, 24, 41, 54, 65, 77, 92, 111, 5, 28, 45, 58, 69, 81, 96, 115};
auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq3nl);
return MM256_SET_M128I(val128, val128);
}
inline __m128i make_scales(uint16_t signs, const uint8_t * scales_l) const {
uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
scl = _mm_add_epi8(_mm_slli_epi16(scl, 1), m1);
const __m128i sc_signs = _mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi16(signs), sign_mask), sign_mask);
const __m128i sch = _mm_shuffle_epi8(_mm_or_si128(sc_signs, _mm_set1_epi8(1)), hshuff);
return _mm_sign_epi8(scl, sch);
}
Q2Bits bits;
const IQXKScales iqxk;
const __m256i values;
__m256i hbits;
const __m256i hmask = _mm256_set1_epi8(4);
const __m128i m1 = _mm_set1_epi8(1);
const __m128i sign_mask = _mm_set_epi64x(0x8080404020201010, 0x0808040402020101);
const __m128i hshuff = _mm_loadu_si128((const __m128i*)k_shuff);
constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
};
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values()) {}
template <typename Q8>
@@ -3048,9 +3093,10 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
std::is_same_v<Dequantizer, DequantizerQ6K> ||
std::is_same_v<Dequantizer, DequantizerIQ2K>||
std::is_same_v<Dequantizer, DequantizerIQ3K>||
std::is_same_v<Dequantizer, DequantizerIQ4K>||
std::is_same_v<Dequantizer, DequantizerIQ5K>||
std::is_same_v<Dequantizer, DequantizerIQ2K>) {
std::is_same_v<Dequantizer, DequantizerIQ5K>) {
m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;