mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512)
This commit is contained in:
@@ -263,12 +263,14 @@ if (GGML_IQK_MUL_MAT)
|
||||
iqk/iqk_gemm_floats.cpp
|
||||
iqk/iqk_gemm_kquants.cpp
|
||||
iqk/iqk_gemm_iquants.cpp
|
||||
iqk/iqk_gemm_iqk_quants.cpp
|
||||
iqk/iqk_gemm_legacy_quants.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||
iqk/iqk_flash_impl.h
|
||||
iqk/iqk_gemm_floats.h
|
||||
iqk/iqk_gemm_kquants.h
|
||||
iqk/iqk_gemm_iquants.h
|
||||
iqk/iqk_gemm_iqk_quants.h
|
||||
iqk/iqk_gemm_legacy_quants.h)
|
||||
if (GGML_IQK_FLASH_ATTENTION)
|
||||
message(STATUS "Enabling IQK Flash Attention kernels")
|
||||
|
||||
@@ -391,6 +391,112 @@ static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, i
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
|
||||
struct BlockPermuter {
|
||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
|
||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
|
||||
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
auto tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
values[0] = _mm512_and_si512(q4bits, ml);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
values[2] = _mm512_and_si512(q4bits, ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64a(const uint8_t * q4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
||||
values[k] = _mm512_and_si512(values[k], ml);
|
||||
}
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||
const BlockPermuter perm;
|
||||
};
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2) {
|
||||
|
||||
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
||||
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
||||
|
||||
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
||||
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
||||
values[0] = _mm512_and_si512(values[0], ml);
|
||||
values[2] = _mm512_and_si512(values[2], ml);
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0x03);
|
||||
BlockPermuter perm;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2, int j) {
|
||||
auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
|
||||
values[0] = _mm256_and_si256(q2bits, ml);
|
||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
|
||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
|
||||
}
|
||||
__m256i values[4];
|
||||
const __m256i ml = _mm256_set1_epi8(0x03);
|
||||
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[2] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[1] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare16(const uint8_t * q4, int j) {
|
||||
values[0] = dequant16(q4 + 64*j + 0);
|
||||
values[1] = dequant16(q4 + 64*j + 16);
|
||||
values[2] = dequant16(q4 + 64*j + 32);
|
||||
values[3] = dequant16(q4 + 64*j + 48);
|
||||
}
|
||||
inline __m256i dequant16(const uint8_t * qs) const {
|
||||
const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
|
||||
const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
|
||||
return _mm256_and_si256(ml, aux256);
|
||||
}
|
||||
__m256i values[4];
|
||||
const __m256i ml = _mm256_set1_epi8(0xf);
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
1277
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
1277
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
|
||||
#endif
|
||||
@@ -149,62 +149,6 @@ struct ScaleIQ4XS {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
//====================================== Zen4 ==================================================
|
||||
|
||||
struct BlockPermuter {
|
||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
|
||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
|
||||
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
auto tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
values[0] = _mm512_and_si512(q4bits, ml);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
values[2] = _mm512_and_si512(q4bits, ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64a(const uint8_t * q4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
||||
values[k] = _mm512_and_si512(values[k], ml);
|
||||
}
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||
const BlockPermuter perm;
|
||||
};
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2) {
|
||||
|
||||
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
||||
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
||||
|
||||
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
||||
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
||||
values[0] = _mm512_and_si512(values[0], ml);
|
||||
values[2] = _mm512_and_si512(values[2], ml);
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0x03);
|
||||
BlockPermuter perm;
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
inline void apply(const uint8_t * h, Q4Bits& bits) {
|
||||
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
|
||||
@@ -524,50 +468,6 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
||||
#else
|
||||
//====================================== AVX2 ==================================================
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2, int j) {
|
||||
auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
|
||||
values[0] = _mm256_and_si256(q2bits, ml);
|
||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
|
||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
|
||||
}
|
||||
__m256i values[4];
|
||||
const __m256i ml = _mm256_set1_epi8(0x03);
|
||||
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[2] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[1] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare16(const uint8_t * q4, int j) {
|
||||
values[0] = dequant16(q4 + 64*j + 0);
|
||||
values[1] = dequant16(q4 + 64*j + 16);
|
||||
values[2] = dequant16(q4 + 64*j + 32);
|
||||
values[3] = dequant16(q4 + 64*j + 48);
|
||||
}
|
||||
inline __m256i dequant16(const uint8_t * qs) const {
|
||||
const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
|
||||
const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
|
||||
return _mm256_and_si256(ml, aux256);
|
||||
}
|
||||
__m256i values[4];
|
||||
const __m256i ml = _mm256_set1_epi8(0xf);
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
|
||||
inline void apply(Q4Bits& bits, bool do_shift) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user