mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-11 14:30:02 +00:00
Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512)
This commit is contained in:
@@ -261,10 +261,12 @@ if (GGML_IQK_MUL_MAT)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/iqk_gemm_floats.cpp
|
||||
iqk/iqk_gemm_kquants.cpp
|
||||
iqk/iqk_gemm_legacy_quants.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||
iqk/iqk_flash_impl.h
|
||||
iqk/iqk_gemm_floats.h
|
||||
iqk/iqk_gemm_kquants.h
|
||||
iqk/iqk_gemm_legacy_quants.h)
|
||||
if (GGML_IQK_FLASH_ATTENTION)
|
||||
message(STATUS "Enabling IQK Flash Attention kernels")
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
@@ -139,6 +140,17 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
|
||||
// ==================================================================================================
|
||||
|
||||
// Unpack the 12 packed 6-bit scale/min values of a q4_K/q5_K super-block
// into four 32-bit words, one 6-bit value per byte. Per the consumers
// below, words 0-1 hold the scales and words 2-3 the mins.
static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
    const uint16_t * s16 = (const uint16_t *)scales8;
    uint32_t w[3];
    for (int k = 0; k < 3; ++k) w[k] = s16[2*k+0] | (s16[2*k+1] << 16);
    aux32[0] = w[0] & 0x3f3f3f3f;
    aux32[2] = w[1] & 0x3f3f3f3f;
    // High two bits of the last 8 values live in the top crumbs of w[0]/w[1].
    aux32[1] = ((w[2] >> 0) & 0x0f0f0f0f) | ((w[0] >> 2) & 0x30303030);
    aux32[3] = ((w[2] >> 4) & 0x0f0f0f0f) | ((w[1] >> 2) & 0x30303030);
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
@@ -236,6 +248,49 @@ template <int nrc> struct Q8_16 {
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
struct Scales8KBase {
|
||||
template <typename Q8>
|
||||
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
|
||||
const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i q8s = q8.load_bsums(iy, i);
|
||||
const __m256i prod = _mm256_madd_epi16(mins, q8s);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
|
||||
}
|
||||
}
|
||||
inline __m256i shuffle(__m128i mins) const {
|
||||
return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
|
||||
}
|
||||
const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
|
||||
_mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
|
||||
};
|
||||
|
||||
template <typename Block, bool per_row_scale = false, bool is_f16 = false>
|
||||
struct BaseDequantizer {
|
||||
BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
|
||||
inline void new_row(int ix) {
|
||||
if constexpr (per_row_scale) {
|
||||
if constexpr (is_f16) {
|
||||
const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix);
|
||||
d = GGML_FP16_TO_FP32(*dptr);
|
||||
x = (const Block *)(dptr + 1);
|
||||
} else {
|
||||
const float * dptr = (const float *)((const char *)vx + bx*ix);
|
||||
d = *dptr;
|
||||
x = (const Block *)(dptr + 1);
|
||||
}
|
||||
} else {
|
||||
x = (const Block *)((const char *)vx + bx*ix);
|
||||
}
|
||||
}
|
||||
|
||||
const void * vx;
|
||||
const size_t bx;
|
||||
const Block * x;
|
||||
|
||||
float d;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
763
ggml/src/iqk/iqk_gemm_kquants.cpp
Normal file
763
ggml/src/iqk/iqk_gemm_kquants.cpp
Normal file
@@ -0,0 +1,763 @@
|
||||
#include "iqk_gemm_kquants.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Handles q4_K and q5_K scales/mins
|
||||
struct Scales8K {
|
||||
template <typename Q8>
|
||||
inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
|
||||
make_q4_scales(data, utmp);
|
||||
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
||||
const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
|
||||
accum_mins(mins128, q8, i, c, accd);
|
||||
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
||||
return MM256_SET_M128I(sc128, sc128);
|
||||
}
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <typename Q8>
|
||||
inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
|
||||
auto scales = process_mins_and_scales(data, c, i, q8, accd);
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);
|
||||
}
|
||||
#endif
|
||||
template <typename Q8>
|
||||
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
|
||||
base.accum_mins(mins128, q8, i, c, accd);
|
||||
}
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
const __m512i shuffles512[2] = {
|
||||
_mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
|
||||
0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
|
||||
_mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
|
||||
0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
|
||||
};
|
||||
#endif
|
||||
Scales8KBase base;
|
||||
|
||||
uint32_t utmp[4];
|
||||
};
|
||||
|
||||
// Fold the 16 per-sub-block mins (already widened to 16-bit lanes) into the
// float accumulators: accm[iy] += d * q8.scale(iy,i) * sum(mins * bsums).
template <typename Q8>
inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        const __m256 vd = _mm256_set1_ps(d * q8.scale(iy, i));
        const __m256i dot = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));
        accm[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(dot), accm[iy]);
    }
}
|
||||
// Duplicate the low and high 128-bit halves of 16 16-bit scales into two
// ymm registers (one per group of 8 sub-blocks).
inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
    const __m128i lo = _mm256_castsi256_si128(all_scales);
    const __m128i hi = _mm256_extracti128_si256(all_scales, 1);
    scales[0] = MM256_SET_M128I(lo, lo);
    scales[1] = MM256_SET_M128I(hi, hi);
}
|
||||
|
||||
// Handles q3_K scales
|
||||
// Unpacks the 12 packed 6-bit scales of a q3_K super-block into 16 signed
// bytes, shifted from [0,63] to [-32,31].
struct ScaleQ3 {
    inline __m128i make_scales(const uint16_t * s8) const {
        const uint16_t * s16 = (const uint16_t *)s8;
        uint32_t w[3];
        for (int k = 0; k < 3; ++k) w[k] = s16[2*k+0] | (s16[2*k+1] << 16);
        // Low nibbles come from w[0]/w[1]; the high crumbs are packed in w[2].
        const __m128i packed = _mm_set_epi32(
                ((w[1] >> 4) & 0x0f0f0f0f) | ((w[2] >> 2) & 0x30303030),
                ((w[0] >> 4) & 0x0f0f0f0f) | ((w[2] >> 0) & 0x30303030),
                ( w[1]       & 0x0f0f0f0f) | ((w[2] << 2) & 0x30303030),
                ( w[0]       & 0x0f0f0f0f) | ((w[2] << 4) & 0x30303030));
        return _mm_add_epi8(packed, m32);
    }
    const __m128i m32 = _mm_set1_epi8(-32);
};
|
||||
|
||||
struct Scale16 {
|
||||
inline void make_scales(const __m128i& scales8, __m512i * scales) const {
|
||||
auto all_scales8 = MM256_SET_M128I(scales8, scales8);
|
||||
auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);
|
||||
auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);
|
||||
scales[0] = _mm512_cvtepi8_epi16(scales1);
|
||||
scales[1] = _mm512_cvtepi8_epi16(scales2);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,
|
||||
const Q8& q8, __m256 * accm, __m512i * scales) const {
|
||||
process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);
|
||||
make_scales(scales8, scales);
|
||||
}
|
||||
const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,
|
||||
0x05050505, 0x01010101, 0x04040404, 0x00000000);
|
||||
const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,
|
||||
0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
|
||||
};
|
||||
|
||||
template <typename Q8>
|
||||
inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
|
||||
__m256 * accm, __m256i * scales) {
|
||||
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
|
||||
process_mins_16(all_scales, q8, i, d, accm);
|
||||
prepare_scales_16(all_scales, scales);
|
||||
}
|
||||
|
||||
// Shuffle mask that broadcasts the byte pair (2i, 2i+1) into every 16-bit lane.
inline __m256i get_scale_shuffle_8(int i) {
    return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
}
|
||||
|
||||
// Broadcast the four consecutive scale pairs belonging to 128-quant chunk j.
inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
    for (int k = 0; k < 4; ++k) {
        scales[k] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j + k));
    }
}
|
||||
|
||||
// Shuffle mask i selects 16-bit scales 4i..4i+3, each repeated 8 times.
inline __m256i get_scale_shuffle_16(int i) {
    static const uint8_t k_shuffle[128] = {
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
        4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
        8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
    };
    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
|
||||
|
||||
// Broadcast all 16 sub-block scales into four ymm registers of 16-bit lanes.
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
    for (int k = 0; k < 4; ++k) {
        scales[k] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(k));
    }
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
//====================================== Zen4 ==================================================
|
||||
|
||||
// 64-bit-lane permutations used with _mm512_permutex2var_epi64 to interleave
// two zmm sources so the extracted quants land in block order.
struct BlockPermuter {
    const __m512i permute1 = _mm512_set_epi64(11, 10,  9,  8, 3, 2, 1, 0);
    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
auto tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
tmp1 = _mm512_and_si512(q4bits, ml);
|
||||
tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
|
||||
values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4) {
|
||||
auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
|
||||
values[0] = _mm512_and_si512(q4bits, ml);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
|
||||
values[2] = _mm512_and_si512(q4bits, ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64a(const uint8_t * q4) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
|
||||
values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
|
||||
values[k] = _mm512_and_si512(values[k], ml);
|
||||
}
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0xf);
|
||||
const BlockPermuter perm;
|
||||
};
|
||||
|
||||
struct Q2Bits {
|
||||
inline void prepare(const uint8_t * q2) {
|
||||
|
||||
auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
|
||||
auto tmp = _mm512_srli_epi16(q2bits, 2);
|
||||
|
||||
values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
|
||||
values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
|
||||
values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
|
||||
values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
|
||||
values[0] = _mm512_and_si512(values[0], ml);
|
||||
values[2] = _mm512_and_si512(values[2], ml);
|
||||
}
|
||||
__m512i values[4];
|
||||
const __m512i ml = _mm512_set1_epi8(0x03);
|
||||
BlockPermuter perm;
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
inline void apply(const uint8_t * h, Q4Bits& bits) {
|
||||
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
|
||||
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
|
||||
bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));
|
||||
bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
|
||||
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));
|
||||
bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
|
||||
}
|
||||
const __m512i mh = _mm512_set1_epi8(0x10);
|
||||
};
|
||||
|
||||
struct HighBit3 {
|
||||
inline void apply(const uint8_t * h, Q2Bits& bits) {
|
||||
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
|
||||
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
|
||||
bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
|
||||
bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));
|
||||
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
|
||||
bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));
|
||||
}
|
||||
const __m512i mh = _mm512_set1_epi8(0x04);
|
||||
};
|
||||
|
||||
|
||||
// One super-block dot product for y-row iy: unsigned*signed byte dot
// products, weighted by the 16-bit sub-block scales, accumulated as floats.
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
    __m512i p[4];
    for (int k = 0; k < 4; ++k) {
        p[k] = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[k], q8.load_quants64(iy, i, k));
    }
    __m512i sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p[0], p[1]));
    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p[2], p[3]));
    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
|
||||
|
||||
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Load the quants of super-block i and split the 16 packed scale/min
    // nibbles; mins go straight into accm, scales into the zmm registers.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        const __m128i packed = _mm_loadu_si128((const __m128i*)x[i].scales);
        const __m128i sc8 = _mm_and_si128(packed, m4);                      // low nibbles: scales
        const __m128i mn8 = _mm_and_si128(_mm_srli_epi16(packed, 4), m4);   // high nibbles: mins
        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mn8, sc8, q8, accm, scales);
    }

    Q2Bits bits;
    Scale16 sc16;
    const __m128i m4 = _mm_set1_epi8(0xf);
};
|
||||
|
||||
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Load quants + high bits of super-block i and prepare the 16 signed
    // sub-block scales. The -4*d constant folded via the bsums presumably
    // compensates the offset of the 3-bit quants — confirm against callers.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        hbits.apply(x[i].hmask, bits);
        auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
    }

    Q2Bits bits;
    HighBit3 hbits;
    ScaleQ3 sc3;
    Scale16 sc16;
    // Unused m4/m32 masks removed: ScaleQ3 carries its own -32 offset and no
    // nibble masking happens in this AVX512 variant.
};
|
||||
|
||||
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Prepare quants and the two zmm scale registers of super-block i;
    // mins are folded into accd using the q8 block sums.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
        for (int k = 0; k < 2; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[k]);
    }

    Q4Bits bits;
    Scales8K s8k;
};
|
||||
|
||||
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Like DequantizerQ4K, plus merging the 5th bit from qh into the quants.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        hbits.apply(x[i].qh, bits);
        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
        for (int k = 0; k < 2; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[k]);
    }

    Q4Bits bits;
    HighBit5 hbits;
    Scales8K s8k;
};
|
||||
|
||||
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Prepare the 4-bit low quants, merge the 2 high bits, and set up the
    // 16 signed sub-block scales. -32*d folded via the bsums presumably
    // compensates the offset of the 6-bit quants — confirm against callers.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare64(x[i].ql);
        add_high_bits(x[i].qh, bits);
        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
    }

    // OR the two high bits (mask 0x30) into the 4-bit values, using the same
    // block permutation as Q4Bits::prepare so lanes line up.
    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
        auto hbits = _mm512_loadu_si512((const __m512i *)qh);
        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
        tmp1 = _mm512_and_si512(hbits, mh);
        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
    }

    Q4Bits bits;
    Scale16 sc16;
    // Removed unused member `HighBit3 hbits;` — q6_K applies its high bits
    // through add_high_bits(), never through HighBit3.

    const __m512i mh = _mm512_set1_epi8(0x30);
};
|
||||
|
||||
template <typename Dequantizer>
|
||||
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
constexpr int k_nx = 2;
|
||||
|
||||
Q8<1> q8(info);
|
||||
|
||||
Dequantizer deq1(vx, bx);
|
||||
Dequantizer deq2(vx, bx);
|
||||
|
||||
Dequantizer * deq[k_nx];
|
||||
deq[0] = &deq1;
|
||||
deq[1] = &deq2;
|
||||
|
||||
__m512i scales[2*k_nx];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
auto accd = _mm512_setzero_ps();
|
||||
auto accm = _mm256_setzero_ps();
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb/k_nx; ++i) {
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
|
||||
}
|
||||
|
||||
}
|
||||
if (2*(nb/2) < nb) {
|
||||
int i0 = 2*(nb/2);
|
||||
deq[0]->new_block(i0, q8, &accm, scales);
|
||||
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
|
||||
}
|
||||
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
|
||||
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx);
|
||||
|
||||
__m256 accm[nrc_y];
|
||||
__m512 accd[nrc_y];
|
||||
__m512i scales[2];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
deq.new_block(i, q8, accm, scales);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
|
||||
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
|
||||
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
|
||||
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
|
||||
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
//====================================== AVX2 ==================================================
|
||||
|
||||
// Extracts 128 2-bit quants (one 256-bit load) into four ymm registers
// of bytes (AVX2 path).
struct Q2Bits {
    inline void prepare(const uint8_t * q2, int j) {
        const __m256i q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
        for (int k = 0; k < 4; ++k) {
            values[k] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2*k), ml);
        }
    }
    __m256i values[4];
    const __m256i ml = _mm256_set1_epi8(0x03);
};
|
||||
|
||||
struct Q4Bits {
|
||||
inline void prepare(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[2] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare64(const uint8_t * q4, int j) {
|
||||
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
|
||||
values[0] = _mm256_and_si256(q4bits, ml);
|
||||
values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
|
||||
values[1] = _mm256_and_si256(q4bits, ml);
|
||||
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
|
||||
}
|
||||
inline void prepare16(const uint8_t * q4, int j) {
|
||||
values[0] = dequant16(q4 + 64*j + 0);
|
||||
values[1] = dequant16(q4 + 64*j + 16);
|
||||
values[2] = dequant16(q4 + 64*j + 32);
|
||||
values[3] = dequant16(q4 + 64*j + 48);
|
||||
}
|
||||
inline __m256i dequant16(const uint8_t * qs) const {
|
||||
const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
|
||||
const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
|
||||
return _mm256_and_si256(ml, aux256);
|
||||
}
|
||||
__m256i values[4];
|
||||
const __m256i ml = _mm256_set1_epi8(0xf);
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
|
||||
inline void apply(Q4Bits& bits, bool do_shift) {
|
||||
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
|
||||
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
|
||||
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
|
||||
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
|
||||
if (do_shift) {
|
||||
hbits = _mm256_srli_epi16(hbits, 4);
|
||||
}
|
||||
}
|
||||
const __m256i mh = _mm256_set1_epi8(0x10);
|
||||
__m256i hbits;
|
||||
};
|
||||
|
||||
struct HighBit3 {
|
||||
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
|
||||
inline void apply(Q2Bits& bits, bool do_shift) {
|
||||
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
|
||||
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
|
||||
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
|
||||
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
|
||||
if (do_shift) {
|
||||
hbits = _mm256_srli_epi16(hbits, 4);
|
||||
}
|
||||
}
|
||||
const __m256i mh = _mm256_set1_epi8(0x04);
|
||||
__m256i hbits;
|
||||
};
|
||||
|
||||
// NOTE(review): this overload uses AVX-512 types/intrinsics (__m512i,
// _mm512_dpbusd_epi32) although it sits in the non-HAVE_FANCY_SIMD (AVX2)
// branch; nothing visible here instantiates it — confirm against callers
// before removing.
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
    __m512i p[4];
    for (int k = 0; k < 4; ++k) {
        p[k] = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[k], q8.load_quants64(iy, i, k));
    }
    __m512i sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p[0], p[1]));
    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p[2], p[3]));
    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
|
||||
|
||||
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    // Split the 16 packed scale/min nibbles of super-block i; mins are
    // folded into accm, scales prepared for the dot products.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        const __m128i packed = _mm_loadu_si128((const __m128i*)x[i].scales);
        const __m128i sc8 = _mm_and_si128(packed, m4);                      // low nibbles: scales
        const __m128i mn8 = _mm_and_si128(_mm_srli_epi16(packed, 4), m4);   // high nibbles: mins
        process_mins_16(_mm256_cvtepi8_epi16(mn8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
        prepare_scales_16(_mm256_cvtepi8_epi16(sc8), scales);
    }
    inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); }

    Q2Bits bits;

    const __m128i m4 = _mm_set1_epi8(0xf);
};
|
||||
|
||||
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    // Load the high-bit plane and prepare the 16 signed sub-block scales.
    // The -4*d folded via the bsums presumably compensates the offset of
    // the 3-bit quants — confirm against callers.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        hbits.load(x[i].hmask);
        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
        hbits.apply(bits, j == 0);
    }

    Q2Bits bits;
    HighBit3 hbits;
    ScaleQ3 sc3;
    // Removed unused member `const __m128i m32;` — ScaleQ3 already carries
    // its own -32 offset mask.
};
|
||||
|
||||
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Returns the 8 sub-block scales duplicated into both ymm halves; the
    // mins are accumulated into accd.
    template <typename Q8>
    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
        d = GGML_FP16_TO_FP32(x[i].d);
        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
    }
    inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); }

    Q4Bits bits;
    Scales8K s8k;
};
|
||||
|
||||
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Like DequantizerQ4K, plus loading the 5th-bit plane for prepare().
    template <typename Q8>
    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
        d = GGML_FP16_TO_FP32(x[i].d);
        hbits.load(x[i].qh);
        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
        hbits.apply(bits, j == 0);
    }

    Q4Bits bits;
    HighBit5 hbits;
    Scales8K s8k;
};
|
||||
|
||||
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // The -32*d folded via the bsums presumably compensates the offset of
    // the 6-bit quants — confirm against callers.
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
    }
    // Merge the 2 high bits (mask 0x30) of chunk j into the low 4-bit values.
    inline void prepare(int i, int j) {
        bits.prepare64(x[i].ql, j);
        const __m256i hb = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hb, 4), mh));
        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hb, 2), mh));
        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hb, mh));
        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hb, 2), mh));
    }

    Q4Bits bits;
    const __m256i mh = _mm256_set1_epi8(0x30);
};
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx);
|
||||
|
||||
__m256 accd[nrc_y];
|
||||
__m256i scales[4];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
auto all_scales = deq.new_block(i, q8, accd);
|
||||
|
||||
__m256i sumi[nrc_y];
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
|
||||
deq.prepare(i, j);
|
||||
|
||||
set_scales_8(all_scales, j, scales);
|
||||
|
||||
multiply_add(deq.bits, scales, j, i, q8, sumi);
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
|
||||
accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
__m256i all_scales[2];
|
||||
__m256i scales[4];
|
||||
__m256 accd[nrc_y];
|
||||
|
||||
Dequantizer deq(vx, bx);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
deq.new_block(i, q8, accd, all_scales);
|
||||
|
||||
__m256i sumi[nrc_y];
|
||||
|
||||
for (int j = 0; j < QK_K/128; ++j) {
|
||||
deq.prepare(i, j);
|
||||
set_scales_16(all_scales[j], scales);
|
||||
multiply_add(deq.bits, scales, j, i, q8, sumi);
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
|
||||
#else
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ6K>) {
|
||||
funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
|
||||
} else {
|
||||
funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
set_functions<DequantizerQ2K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
set_functions<DequantizerQ3K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
set_functions<DequantizerQ4K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
set_functions<DequantizerQ5K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
set_functions<DequantizerQ6K>(kernels);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
11
ggml/src/iqk/iqk_gemm_kquants.h
Normal file
11
ggml/src/iqk/iqk_gemm_kquants.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "iqk_common.h"
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
|
||||
#endif
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "iqk_quantize.h"
|
||||
#include "iqk_flash_impl.h"
|
||||
#include "iqk_gemm_floats.h"
|
||||
#include "iqk_gemm_kquants.h"
|
||||
#include "iqk_gemm_legacy_quants.h"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
@@ -484,17 +485,6 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
||||
|
||||
namespace {
|
||||
|
||||
// Unpack the 12-byte packed scale/min area of a q4_K/q5_K super-block into
// 16 bytes (4 x uint32), one 6-bit value per output byte. Entries 0 and 2 are
// taken directly from the low 6 bits; entries 1 and 3 are reassembled from the
// split high bits. (Exact scale/min layout per the k-quant packing — verify
// against the block_q4_K definition.)
inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
    const uint16_t * scales = (const uint16_t *)scales8;
    const uint32_t a0 = scales[0] | (scales[1] << 16);
    const uint32_t a1 = scales[2] | (scales[3] << 16);
    const uint32_t a2 = scales[4] | (scales[5] << 16);
    // Low 4 bits come from a2; the upper 2 bits from the top bits of a0/a1.
    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
    aux32[2] = a1 & 0x3f3f3f3f;
    aux32[0] = a0 & 0x3f3f3f3f;
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
static const uint64_t iq1s_grid_us[2048] = {
|
||||
0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,
|
||||
@@ -1314,23 +1304,6 @@ const uint64_t keven_signs[128] = {
|
||||
|
||||
namespace {
|
||||
|
||||
// Shared helpers for super-blocks with 8 scales / 8 mins (q4_K/q5_K style).
struct Scales8KBase {
    // Min-correction: expand the 8 16-bit mins to 16 lanes, dot them with the
    // per-block sums of each y row, and fold c * y_scale * result into accd.
    // c is expected to carry the (negated) super-block min scale.
    template <typename Q8>
    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            const __m256i q8s = q8.load_bsums(iy, i);
            const __m256i prod = _mm256_madd_epi16(mins, q8s);
            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
        }
    }
    // Duplicate each of the 8 16-bit values so they fill all 16 lanes of a __m256i.
    inline __m256i shuffle(__m128i mins) const {
        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
    }
    // Byte shuffles that repeat each 16-bit element twice (low half, high half).
    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
};
|
||||
|
||||
// Handles q4_K and q5_K scales/mins
|
||||
struct Scales8K {
|
||||
template <typename Q8>
|
||||
@@ -1639,21 +1612,6 @@ struct Q2Bits {
|
||||
BlockPermuter perm;
|
||||
};
|
||||
|
||||
// Q4_K dequantizer (AVX512 path): 4-bit quants, 8 scales + 8 mins per super-block.
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        // Min corrections go straight into accd (note the negated dmin);
        // the 8 scales are then spread across two 512-bit registers.
        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
    }

    Q4Bits bits;
    Scales8K s8k;
};
|
||||
|
||||
__m512i inline load_iq4nl_values_512() {
|
||||
auto val256 = load_iq4nl_values_256();
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
@@ -1726,23 +1684,6 @@ struct HighBit3 {
|
||||
const __m512i mh = _mm512_set1_epi8(0x04);
|
||||
};
|
||||
|
||||
// Q5_K dequantizer (AVX512 path): like Q4_K plus one extra high bit per quant (qh).
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        hbits.apply(x[i].qh, bits);  // fold the 5th bit into the 4-bit values
        // Min corrections go straight into accd; scales spread over two zmm registers.
        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
    }

    Q4Bits bits;
    HighBit5 hbits;
    Scales8K s8k;
};
|
||||
|
||||
struct Scale16 {
|
||||
inline void make_scales(const __m128i& scales8, __m512i * scales) const {
|
||||
auto all_scales8 = MM256_SET_M128I(scales8, scales8);
|
||||
@@ -1763,74 +1704,6 @@ struct Scale16 {
|
||||
0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
|
||||
};
|
||||
|
||||
// Q2_K dequantizer (AVX512 path): 2-bit quants; each byte of x[i].scales packs
// a scale in the low nibble and a min in the high nibble.
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        // Split the packed nibbles into separate scale and min vectors.
        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);
    }

    Q2Bits bits;
    Scale16 sc16;
    const __m128i m4 = _mm_set1_epi8(0xf);  // nibble mask

};
|
||||
|
||||
// Q3_K dequantizer (AVX512 path): 2 low bits per quant in qs plus 1 high bit in
// hmask; packed scales are unpacked by ScaleQ3.
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare(x[i].qs);
        hbits.apply(x[i].hmask, bits);
        auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
        // scales128 is passed as both mins and scales; -4.f*d presumably folds
        // the quants' implicit offset into the min correction — confirm.
        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
    }

    Q2Bits bits;
    HighBit3 hbits;
    ScaleQ3 sc3;
    Scale16 sc16;
    const __m128i m4 = _mm_set1_epi8(0xf);   // NOTE(review): appears unused in this struct — verify before removing
    const __m128i m32 = _mm_set1_epi8(-32);  // NOTE(review): appears unused here — verify before removing
};
|
||||
|
||||
// Q6_K dequantizer (AVX512 path): 4 low bits per quant in ql, 2 high bits in qh,
// 16 int8 scales per super-block.
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        bits.prepare64(x[i].ql);
        add_high_bits(x[i].qh, bits);
        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
        // scales128 doubles as mins; -32.f*d presumably folds the quants'
        // implicit offset into the min correction — confirm.
        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
    }

    // Merge the two high bits of each quant (from qh) into bit positions 4..5
    // (mask 0x30) of the 4-bit values, routed through the same 64-byte
    // permutation (bits.perm) that prepare64 applied to the low bits.
    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
        auto hbits = _mm512_loadu_si512((const __m512i *)qh);
        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
        tmp1 = _mm512_and_si512(hbits, mh);
        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
    }

    Q4Bits bits;
    HighBit3 hbits;  // NOTE(review): appears unused in this struct — verify before removing
    Scale16 sc16;

    const __m512i mh = _mm512_set1_epi8(0x30);  // selects the two high bits after shifting

};
|
||||
|
||||
struct IQXKScales {
|
||||
IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
|
||||
template <typename Q8>
|
||||
@@ -2670,21 +2543,6 @@ struct HighBit3 {
|
||||
__m256i hbits;
|
||||
};
|
||||
|
||||
// Q4_K dequantizer (AVX2 path): 4-bit quants, 8 scales + 8 mins per super-block.
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
        d = GGML_FP16_TO_FP32(x[i].d);
        // Returns the block scales; min corrections (negated dmin) go into accd.
        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
    }
    // Unpack the j-th 128-quant half of super-block i.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
    }

    Q4Bits bits;
    Scales8K s8k;
};
|
||||
|
||||
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
|
||||
template <typename Q8>
|
||||
@@ -3111,24 +2969,6 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
|
||||
const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
|
||||
};
|
||||
|
||||
// Q5_K dequantizer (AVX2 path): like Q4_K plus one extra high bit per quant (qh).
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
        d = GGML_FP16_TO_FP32(x[i].d);
        hbits.load(x[i].qh);
        // Returns the block scales; min corrections (negated dmin) go into accd.
        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
        // j == 0 flags the first 128-quant half — presumably so apply() resets
        // its high-bit shift state; confirm against HighBit5.
        hbits.apply(bits, j == 0);
    }

    Q4Bits bits;
    HighBit5 hbits;
    Scales8K s8k;
};
|
||||
|
||||
template <typename Q8>
|
||||
inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
|
||||
__m256 * accm, __m256i * scales) {
|
||||
@@ -3137,68 +2977,6 @@ inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, i
|
||||
prepare_scales_16(all_scales, scales);
|
||||
}
|
||||
|
||||
// Q3_K dequantizer (AVX2 path): 2 low bits per quant in qs plus 1 high bit in
// hmask; packed scales unpacked by ScaleQ3.
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        hbits.load(x[i].hmask);
        // Scales double as mins; -4.f*d presumably folds the quants' implicit
        // offset into the min correction — confirm.
        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
        // j == 0 flags the first 128-quant half — presumably resets the
        // high-bit shift state in HighBit3; confirm.
        hbits.apply(bits, j == 0);
    }

    Q2Bits bits;
    HighBit3 hbits;
    ScaleQ3 sc3;

    const __m128i m32 = _mm_set1_epi8(-32);  // NOTE(review): appears unused here — verify before removing
};
|
||||
|
||||
// Q2_K dequantizer (AVX2 path): 2-bit quants; each byte of x[i].scales packs a
// scale in the low nibble and a min in the high nibble.
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        // Split the packed nibbles into separate scale and min vectors.
        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
        // Min corrections (negated dmin) go into accm; scales are widened to 16 bit.
        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
        prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs, j);
    }

    Q2Bits bits;

    const __m128i m4 = _mm_set1_epi8(0xf);  // nibble mask
};
|
||||
|
||||
// Q6_K dequantizer (AVX2 path): 4 low bits per quant in ql, 2 high bits in qh,
// 16 int8 scales per super-block.
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    template <typename Q8>
    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
        d = GGML_FP16_TO_FP32(x[i].d);
        // scales double as mins; -32.f*d presumably folds the quants' implicit
        // offset into the min correction — confirm.
        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
    }
    inline void prepare(int i, int j) {
        bits.prepare64(x[i].ql, j);
        // Merge the two high bits of each quant (from qh) into bit positions
        // 4..5 (mask 0x30) of the 4-bit values already in bits.values.
        auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
    }

    Q4Bits bits;
    const __m256i mh = _mm256_set1_epi8(0x30);  // selects the two high bits after shifting
};
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -8473,25 +8251,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerQ2K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerQ3K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerQ4K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerQ5K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerQ6K>(mm);
|
||||
break;
|
||||
return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs) : false;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4XS>(mm);
|
||||
|
||||
Reference in New Issue
Block a user