mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
iqk_mul_mat: AVX2 implementation for iq3_xxs
We get a 2.3X speedup for PP-512 (prompt processing), reaching 87 t/s. For TG (token generation), however, we still need to use the original implementation in llama.cpp, because the templated implementation cannot match the performance of that special-purpose code. Also, 87 t/s is significantly lower than the 111 t/s I get in iquants.
This commit is contained in:
130
iqk_mul_mat.cpp
130
iqk_mul_mat.cpp
@@ -127,6 +127,41 @@ inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
|
||||
aux32[0] = a0 & 0x3f3f3f3f;
|
||||
}
|
||||
|
||||
// Sign lookup table with 128 entries, i.e. indexed by a 7-bit value.
// Every entry packs eight bytes, each either 0x01 (+1) or 0xff (-1 when read as int8).
// Counting from the least significant byte, byte k (k = 0..6) is 0xff exactly when
// bit k of the index is set; the most significant byte is chosen so that the total
// number of 0xff bytes in the entry is even (hence "even signs").
const uint64_t keven_signs[128] = {
|
||||
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
|
||||
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
|
||||
0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
|
||||
0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
|
||||
0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
|
||||
0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
|
||||
0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
|
||||
0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
|
||||
0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
|
||||
0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B,
|
||||
@@ -1087,11 +1122,17 @@ struct SimpleBits {
|
||||
};
|
||||
|
||||
struct SignHelper {
|
||||
inline __m256i make_signs(const uint16_t * sign_bits) const {
|
||||
auto aux256 = _mm256_set1_epi32(sign_bits[0] | (sign_bits[1] << 16));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
|
||||
// Expands 32 sign bits into 32 bytes: output byte n is 0xff (-1 as int8) when bit n of
// sign_bits is set and 0x01 (+1) when it is clear.
inline __m256i make_signs(uint32_t sign_bits) const {
|
||||
// Broadcast the 32 bits so that each 128-bit lane holds the four source bytes.
auto aux256 = _mm256_set1_epi32(sign_bits);
|
||||
// mask1 replicates source byte n/8 into output byte n; mask2 then keeps only bit n%8 of it.
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
|
||||
// The compare yields 0xff where the bit was set and 0x00 otherwise; OR-ing with mone
// (all bytes 1) turns the 0x00 bytes into 0x01 and leaves the 0xff bytes unchanged.
return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
|
||||
}
|
||||
// Overload taking the sign bits as two consecutive 16-bit words: sign_bits[0] supplies
// bits 0..15 and sign_bits[1] bits 16..31 of the value handed to make_signs(uint32_t).
inline __m256i make_signs(const uint16_t * sign_bits) const {
|
||||
return make_signs(sign_bits[0] | (sign_bits[1] << 16));
|
||||
// Previous inline implementation, kept for reference; it performs the same steps as
// the uint32_t overload this function now delegates to.
//auto aux256 = _mm256_set1_epi32(sign_bits[0] | (sign_bits[1] << 16));
|
||||
//aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
|
||||
//return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
|
||||
}
|
||||
const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
|
||||
const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull);
|
||||
const __m256i mone = _mm256_set1_epi8(1);
|
||||
@@ -1153,6 +1194,46 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
|
||||
const __m256i idx_mask = _mm256_set1_epi32(256);
|
||||
const __m256i min_value = _mm256_set1_epi8(minv);
|
||||
|
||||
};
|
||||
|
||||
// AVX2 dequantizer for block_iq3_xxs data, used with the mul_mat_qX_K_q8_K_IQ kernels.
// new_block() handles the per-block scales, prepare() unpacks 4 x 32 quant bytes into `bits`.
// NOTE(review): `x` and `d` are not declared here; presumably they are inherited from
// BaseDequantizer (block pointer and block scale) - confirm against the base class.
struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
|
||||
DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
|
||||
|
||||
// Starts processing block i: sets the block scale d and extracts the eight sub-block scales.
// Returns the eight 16-bit scales; MM256_SET_M128I presumably places the same 128-bit
// vector in both halves of the result - TODO confirm against the macro definition.
template <typename Q8>
|
||||
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
|
||||
// Block scale: fp16 value stored in the block, converted to float and multiplied by 0.25.
d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
|
||||
// Load the eight 32-bit words that start QK_K/4 bytes into qs.
auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
|
||||
// The top 4 bits of each word are the sub-block scale (make1 uses the low 28 bits for signs).
auto scales32 = _mm256_srli_epi32(tmp, 28);
|
||||
// Map each 4-bit scale s to 2*s + 1.
scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1));
|
||||
// Pack the eight 32-bit scales into eight 16-bit values, low half first.
auto scales16 = _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
|
||||
// NOTE(review): presumably this compensates for the +minv offset that make1 adds to
// every quant (hence the factor -minv*d); accum_mins is not visible here - confirm.
scb.accum_mins(scales16, q8, i, -minv*d, accd);
|
||||
return MM256_SET_M128I(scales16, scales16);
|
||||
}
|
||||
|
||||
// Builds 32 quant bytes from eight grid indices (qs[0..7]) and two 16-bit words (sidx[0..1]).
inline static __m256i make1(const uint8_t * qs, const uint16_t * sidx, const __m256i& min_value) {
|
||||
// Each index selects one 32-bit entry of iq3xxs_grid; qs[0] lands in the lowest 32 bits.
auto val = _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
|
||||
iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
|
||||
// Reassemble the 32-bit word from its two 16-bit halves.
uint32_t aux32 = sidx[0] | (sidx[1] << 16);
|
||||
// Four 7-bit fields (bits 0-6, 7-13, 14-20, 21-27) each select a 64-bit entry of
// keven_signs, i.e. the signs for eight consecutive bytes; bits 28-31 are not used here.
auto s = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
|
||||
keven_signs[(aux32 >> 7) & 127], keven_signs[aux32 & 127]);
|
||||
// Apply the signs to the grid values, then add min_value to every byte.
return _mm256_add_epi8(_mm256_sign_epi8(val, s), min_value);
|
||||
}
|
||||
|
||||
// Unpacks the j-th group of 32 index bytes of block i into bits.values[0..3].
inline void prepare(int i, int j) {
|
||||
// 32 grid-index bytes per call.
auto qs = x[i].qs + 32*j;
|
||||
// Sign words are read as 16-bit values starting QK_K/4 bytes into qs; each call uses 8 of them.
const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
|
||||
// Each make1 call consumes 8 index bytes and one 32-bit word (two 16-bit halves).
bits.values[0] = make1(qs+ 0, signs+0, min_value);
|
||||
bits.values[1] = make1(qs+ 8, signs+2, min_value);
|
||||
bits.values[2] = make1(qs+16, signs+4, min_value);
|
||||
bits.values[3] = make1(qs+24, signs+6, min_value);
|
||||
}
|
||||
|
||||
// Offset added to every quant byte in make1 and passed (as -minv*d) to accum_mins in new_block.
constexpr static int minv = 64;
|
||||
|
||||
// Holds the four 256-bit vectors filled by prepare().
SimpleBits bits;
|
||||
Scales8KBase scb;
|
||||
// minv broadcast to all 32 bytes.
const __m256i min_value = _mm256_set1_epi8(minv);
|
||||
|
||||
};
|
||||
//
|
||||
// ============================== Legacy quants
|
||||
@@ -1528,7 +1609,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
|
||||
m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
|
||||
}
|
||||
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S>) {
|
||||
else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) {
|
||||
m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
|
||||
m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
|
||||
m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
|
||||
@@ -1576,7 +1657,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
|
||||
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny) {
|
||||
|
||||
if (Ny == 1 && typeA == GGML_TYPE_IQ3_S) {
|
||||
if (Ny == 1 && (typeA == GGML_TYPE_IQ3_S || typeA == GGML_TYPE_IQ3_XXS)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1611,6 +1692,10 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ3S>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ3XXS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
assert (ne00 % QK4_0 == 0);
|
||||
MulMat::set_functions<Q4_0_Unpacker>(mm);
|
||||
@@ -2177,41 +2262,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
float d;
|
||||
};
|
||||
|
||||
// Sign lookup table with 128 entries, i.e. indexed by a 7-bit value.
// Every entry packs eight bytes, each either 0x01 (+1) or 0xff (-1 when read as int8).
// Counting from the least significant byte, byte k (k = 0..6) is 0xff exactly when
// bit k of the index is set; the most significant byte is chosen so that the total
// number of 0xff bytes in the entry is even (hence "even signs").
const uint64_t keven_signs[128] = {
|
||||
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
|
||||
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
|
||||
0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
|
||||
0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
|
||||
0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
|
||||
0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
|
||||
0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
|
||||
0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
|
||||
0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
|
||||
0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
struct SimpleBits {
|
||||
uint8x16x4_t b1;
|
||||
uint8x16x4_t b2;
|
||||
|
||||
Reference in New Issue
Block a user