mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512)
This commit is contained in:
@@ -263,12 +263,14 @@ if (GGML_IQK_MUL_MAT)
|
|||||||
iqk/iqk_gemm_floats.cpp
|
iqk/iqk_gemm_floats.cpp
|
||||||
iqk/iqk_gemm_kquants.cpp
|
iqk/iqk_gemm_kquants.cpp
|
||||||
iqk/iqk_gemm_iquants.cpp
|
iqk/iqk_gemm_iquants.cpp
|
||||||
|
iqk/iqk_gemm_iqk_quants.cpp
|
||||||
iqk/iqk_gemm_legacy_quants.cpp)
|
iqk/iqk_gemm_legacy_quants.cpp)
|
||||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||||
iqk/iqk_flash_impl.h
|
iqk/iqk_flash_impl.h
|
||||||
iqk/iqk_gemm_floats.h
|
iqk/iqk_gemm_floats.h
|
||||||
iqk/iqk_gemm_kquants.h
|
iqk/iqk_gemm_kquants.h
|
||||||
iqk/iqk_gemm_iquants.h
|
iqk/iqk_gemm_iquants.h
|
||||||
|
iqk/iqk_gemm_iqk_quants.h
|
||||||
iqk/iqk_gemm_legacy_quants.h)
|
iqk/iqk_gemm_legacy_quants.h)
|
||||||
if (GGML_IQK_FLASH_ATTENTION)
|
if (GGML_IQK_FLASH_ATTENTION)
|
||||||
message(STATUS "Enabling IQK Flash Attention kernels")
|
message(STATUS "Enabling IQK Flash Attention kernels")
|
||||||
|
|||||||
@@ -391,6 +391,112 @@ static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef HAVE_FANCY_SIMD
|
||||||
|
|
||||||
|
struct BlockPermuter {
|
||||||
|
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
|
||||||
|
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Q4Bits {
|
||||||
|
inline void prepare(const uint8_t * q4) {
|
||||||
|
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||||
|
auto tmp1 = _mm512_and_si512(q4bits, ml);
|
||||||
|
auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||||
|
values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||||
|
values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||||
|
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||||
|
tmp1 = _mm512_and_si512(q4bits, ml);
|
||||||
|
tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||||
|
values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||||
|
values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||||
|
}
|
||||||
|
inline void prepare64(const uint8_t * q4) {
|
||||||
|
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||||
|
values[0] = _mm512_and_si512(q4bits, ml);
|
||||||
|
values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||||
|
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||||
|
values[2] = _mm512_and_si512(q4bits, ml);
|
||||||
|
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||||
|
}
|
||||||
|
inline void prepare64a(const uint8_t * q4) {
|
||||||
|
for (int k = 0; k < 4; ++k) {
|
||||||
|
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
||||||
|
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
||||||
|
values[k] = _mm512_and_si512(values[k], ml);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__m512i values[4];
|
||||||
|
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||||
|
const BlockPermuter perm;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Q2Bits {
|
||||||
|
inline void prepare(const uint8_t * q2) {
|
||||||
|
|
||||||
|
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
||||||
|
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
||||||
|
|
||||||
|
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
||||||
|
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
||||||
|
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
||||||
|
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
||||||
|
values[0] = _mm512_and_si512(values[0], ml);
|
||||||
|
values[2] = _mm512_and_si512(values[2], ml);
|
||||||
|
}
|
||||||
|
__m512i values[4];
|
||||||
|
const __m512i ml = _mm512_set1_epi8(0x03);
|
||||||
|
BlockPermuter perm;
|
||||||
|
};
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
struct Q2Bits {
|
||||||
|
inline void prepare(const uint8_t * q2, int j) {
|
||||||
|
auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
|
||||||
|
values[0] = _mm256_and_si256(q2bits, ml);
|
||||||
|
values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
|
||||||
|
values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
|
||||||
|
values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
|
||||||
|
}
|
||||||
|
__m256i values[4];
|
||||||
|
const __m256i ml = _mm256_set1_epi8(0x03);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Q4Bits {
|
||||||
|
inline void prepare(const uint8_t * q4, int j) {
|
||||||
|
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||||
|
values[0] = _mm256_and_si256(q4bits, ml);
|
||||||
|
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||||
|
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||||
|
values[2] = _mm256_and_si256(q4bits, ml);
|
||||||
|
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||||
|
}
|
||||||
|
inline void prepare64(const uint8_t * q4, int j) {
|
||||||
|
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||||
|
values[0] = _mm256_and_si256(q4bits, ml);
|
||||||
|
values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||||
|
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||||
|
values[1] = _mm256_and_si256(q4bits, ml);
|
||||||
|
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||||
|
}
|
||||||
|
inline void prepare16(const uint8_t * q4, int j) {
|
||||||
|
values[0] = dequant16(q4 + 64*j + 0);
|
||||||
|
values[1] = dequant16(q4 + 64*j + 16);
|
||||||
|
values[2] = dequant16(q4 + 64*j + 32);
|
||||||
|
values[3] = dequant16(q4 + 64*j + 48);
|
||||||
|
}
|
||||||
|
inline __m256i dequant16(const uint8_t * qs) const {
|
||||||
|
const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
|
||||||
|
const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
|
||||||
|
return _mm256_and_si256(ml, aux256);
|
||||||
|
}
|
||||||
|
__m256i values[4];
|
||||||
|
const __m256i ml = _mm256_set1_epi8(0xf);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
1277
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
1277
ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Normal file
File diff suppressed because it is too large
Load Diff
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
11
ggml/src/iqk/iqk_gemm_iqk_quants.h
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "iqk_common.h"
|
||||||
|
|
||||||
|
#ifdef IQK_IMPLEMENT
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -149,62 +149,6 @@ struct ScaleIQ4XS {
|
|||||||
#ifdef HAVE_FANCY_SIMD
|
#ifdef HAVE_FANCY_SIMD
|
||||||
//====================================== Zen4 ==================================================
|
//====================================== Zen4 ==================================================
|
||||||
|
|
||||||
struct BlockPermuter {
|
|
||||||
const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
|
|
||||||
const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Q4Bits {
|
|
||||||
inline void prepare(const uint8_t * q4) {
|
|
||||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
|
||||||
auto tmp1 = _mm512_and_si512(q4bits, ml);
|
|
||||||
auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
|
||||||
values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
|
||||||
values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
|
||||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
|
||||||
tmp1 = _mm512_and_si512(q4bits, ml);
|
|
||||||
tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
|
||||||
values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
|
||||||
values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
|
||||||
}
|
|
||||||
inline void prepare64(const uint8_t * q4) {
|
|
||||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
|
||||||
values[0] = _mm512_and_si512(q4bits, ml);
|
|
||||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
|
||||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
|
||||||
values[2] = _mm512_and_si512(q4bits, ml);
|
|
||||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
|
||||||
}
|
|
||||||
inline void prepare64a(const uint8_t * q4) {
|
|
||||||
for (int k = 0; k < 4; ++k) {
|
|
||||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
|
||||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
|
||||||
values[k] = _mm512_and_si512(values[k], ml);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__m512i values[4];
|
|
||||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
|
||||||
const BlockPermuter perm;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Q2Bits {
|
|
||||||
inline void prepare(const uint8_t * q2) {
|
|
||||||
|
|
||||||
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
|
||||||
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
|
||||||
|
|
||||||
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
|
||||||
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
|
||||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
|
||||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
|
||||||
values[0] = _mm512_and_si512(values[0], ml);
|
|
||||||
values[2] = _mm512_and_si512(values[2], ml);
|
|
||||||
}
|
|
||||||
__m512i values[4];
|
|
||||||
const __m512i ml = _mm512_set1_epi8(0x03);
|
|
||||||
BlockPermuter perm;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct HighBit5 {
|
struct HighBit5 {
|
||||||
inline void apply(const uint8_t * h, Q4Bits& bits) {
|
inline void apply(const uint8_t * h, Q4Bits& bits) {
|
||||||
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
|
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
|
||||||
@@ -524,50 +468,6 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
|||||||
#else
|
#else
|
||||||
//====================================== AVX2 ==================================================
|
//====================================== AVX2 ==================================================
|
||||||
|
|
||||||
struct Q2Bits {
|
|
||||||
inline void prepare(const uint8_t * q2, int j) {
|
|
||||||
auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
|
|
||||||
values[0] = _mm256_and_si256(q2bits, ml);
|
|
||||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
|
|
||||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
|
|
||||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
|
|
||||||
}
|
|
||||||
__m256i values[4];
|
|
||||||
const __m256i ml = _mm256_set1_epi8(0x03);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Q4Bits {
|
|
||||||
inline void prepare(const uint8_t * q4, int j) {
|
|
||||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
|
||||||
values[0] = _mm256_and_si256(q4bits, ml);
|
|
||||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
|
||||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
|
||||||
values[2] = _mm256_and_si256(q4bits, ml);
|
|
||||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
|
||||||
}
|
|
||||||
inline void prepare64(const uint8_t * q4, int j) {
|
|
||||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
|
||||||
values[0] = _mm256_and_si256(q4bits, ml);
|
|
||||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
|
||||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
|
||||||
values[1] = _mm256_and_si256(q4bits, ml);
|
|
||||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
|
||||||
}
|
|
||||||
inline void prepare16(const uint8_t * q4, int j) {
|
|
||||||
values[0] = dequant16(q4 + 64*j + 0);
|
|
||||||
values[1] = dequant16(q4 + 64*j + 16);
|
|
||||||
values[2] = dequant16(q4 + 64*j + 32);
|
|
||||||
values[3] = dequant16(q4 + 64*j + 48);
|
|
||||||
}
|
|
||||||
inline __m256i dequant16(const uint8_t * qs) const {
|
|
||||||
const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
|
|
||||||
const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
|
|
||||||
return _mm256_and_si256(ml, aux256);
|
|
||||||
}
|
|
||||||
__m256i values[4];
|
|
||||||
const __m256i ml = _mm256_set1_epi8(0xf);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct HighBit5 {
|
struct HighBit5 {
|
||||||
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
|
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
|
||||||
inline void apply(Q4Bits& bits, bool do_shift) {
|
inline void apply(Q4Bits& bits, bool do_shift) {
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user