mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 18:32:04 +00:00
Refactor iqk: fix AVX2
This commit is contained in:
@@ -1045,10 +1045,12 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
|
||||
|
||||
auto expected_typeB = GGML_TYPE_Q8_K128;
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
if (ne00%QK_K != 0) return false;
|
||||
@@ -1076,7 +1078,7 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
|
||||
funcs[6] = mul_mat_iq1_s_r4_q8_1<7>;
|
||||
funcs[7] = mul_mat_iq1_s_r4_q8_1<8>;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
mm.func16 = mul_mat_iq1_s_r4_q8_1<16>;
|
||||
func16 = mul_mat_iq1_s_r4_q8_1<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M_R4:
|
||||
@@ -1090,7 +1092,7 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
|
||||
funcs[6] = mul_mat_iq1_m_r4_q8_0<7>;
|
||||
funcs[7] = mul_mat_iq1_m_r4_q8_0<8>;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
mm.func16 = mul_mat_iq1_m_r4_q8_0<16>;
|
||||
func16 = mul_mat_iq1_m_r4_q8_0<16>;
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
#include <array>
|
||||
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels);
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -668,6 +668,106 @@ static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, con
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
|
||||
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
|
||||
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
|
||||
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
|
||||
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
}
|
||||
|
||||
template <typename Dequantizer>
|
||||
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
constexpr int k_nx = 2;
|
||||
|
||||
Q8<1> q8(info);
|
||||
|
||||
Dequantizer deq1(vx, bx);
|
||||
Dequantizer deq2(vx, bx);
|
||||
|
||||
Dequantizer * deq[k_nx];
|
||||
deq[0] = &deq1;
|
||||
deq[1] = &deq2;
|
||||
|
||||
__m512i scales[2*k_nx];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
auto accd = _mm512_setzero_ps();
|
||||
auto accm = _mm256_setzero_ps();
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb/k_nx; ++i) {
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
|
||||
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
|
||||
}
|
||||
|
||||
}
|
||||
if (2*(nb/2) < nb) {
|
||||
int i0 = 2*(nb/2);
|
||||
deq[0]->new_block(i0, q8, &accm, scales);
|
||||
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
|
||||
}
|
||||
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
|
||||
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx);
|
||||
|
||||
__m256 accm[nrc_y];
|
||||
__m512 accd[nrc_y];
|
||||
__m512i scales[2];
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
deq.new_block(i, q8, accm, scales);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
|
||||
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
|
||||
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
|
||||
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
|
||||
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
|
||||
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
|
||||
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
@@ -1227,6 +1327,15 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
|
||||
funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
|
||||
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>) {
|
||||
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
|
||||
funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
|
||||
funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
|
||||
funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
|
||||
funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
|
||||
funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
|
||||
funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
|
||||
} else {
|
||||
funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
|
||||
funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
|
||||
@@ -1274,27 +1383,27 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
|
||||
}
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
set_functions<DequantizerIQ4KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
set_functions<DequantizerIQ5KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
set_functions<DequantizerIQ4KSS>(kernels);
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
set_functions<DequantizerIQ2KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
set_functions<DequantizerIQ2K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
set_functions<DequantizerIQ2KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
set_functions<DequantizerIQ3K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
set_functions<DequantizerIQ4KSS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
set_functions<DequantizerIQ4KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
set_functions<DequantizerIQ4K>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
set_functions<DequantizerIQ5KS>(kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
set_functions<DequantizerIQ5K>(kernels);
|
||||
break;
|
||||
|
||||
@@ -5824,7 +5824,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
case GGML_TYPE_IQ1_M_R4:
|
||||
return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs);
|
||||
return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
|
||||
|
||||
default:
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user