mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Refactor iqk: fix AVX2
This commit is contained in:
@@ -291,6 +291,69 @@ struct BaseDequantizer {
|
||||
float d;
|
||||
};
|
||||
|
||||
// Adds four scaled integer dot products per q8 row into sumi[0 .. Q8::nrc_y - 1].
//
//   bits   - provides bits.values[0..3]; used as the first (unsigned-byte) operand of
//            _mm256_maddubs_epi16
//   scales - four vectors of 16-bit multipliers, one per entry of bits.values
//   j      - 0: sumi[iy] is overwritten, q8 quants are loaded with indices 0..3;
//            anything else: the result is added to sumi[iy], indices 4..7 are used
//   i      - forwarded unchanged to q8.load_quants(iy, i, index)
//   q8     - source of the signed-byte operand
//   sumi   - per-row 32-bit accumulators
template <typename Q8, typename Bits>
static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    const bool start_sum = j == 0;
    const int  first     = start_sum ? 0 : 4;
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        // 16-bit pair sums of bits.values[k] times the q8 quants with index first + k
        const auto pair_sums = [&](int k) {
            return _mm256_maddubs_epi16(bits.values[k], q8.load_quants(iy, i, first + k));
        };
#ifdef HAVE_FANCY_SIMD
        // dpwssd fuses the 16-bit multiply by the scale with the 32-bit accumulate.
        // sumi[iy] is only read when we are continuing a sum.
        __m256i acc = start_sum ? _mm256_setzero_si256() : sumi[iy];
        acc = _mm256_dpwssd_epi32(acc, scales[0], pair_sums(0));
        acc = _mm256_dpwssd_epi32(acc, scales[1], pair_sums(1));
        acc = _mm256_dpwssd_epi32(acc, scales[2], pair_sums(2));
        acc = _mm256_dpwssd_epi32(acc, scales[3], pair_sums(3));
        sumi[iy] = acc;
#else
        const __m256i q0 = _mm256_madd_epi16(scales[0], pair_sums(0));
        const __m256i q1 = _mm256_madd_epi16(scales[1], pair_sums(1));
        const __m256i q2 = _mm256_madd_epi16(scales[2], pair_sums(2));
        const __m256i q3 = _mm256_madd_epi16(scales[3], pair_sums(3));
        // 32-bit adds wrap, so the grouping of the partial sums does not affect the result.
        const __m256i total = _mm256_add_epi32(_mm256_add_epi32(q0, q2), _mm256_add_epi32(q1, q3));
        sumi[iy] = start_sum ? total : _mm256_add_epi32(sumi[iy], total);
#endif
    }
}
|
||||
|
||||
// Same contract as multiply_add(), but for bits.values that hold signed bytes.
//
// _mm256_maddubs_epi16 treats its first operand as unsigned bytes, so each product
// bits.values[k] * q8 is formed as |bits.values[k]| * (q8 with the sign of bits.values[k] applied):
//   _mm256_sign_epi8(v, v) yields the magnitude of v,
//   _mm256_sign_epi8(q, v) negates q where v < 0 and zeroes it where v == 0.
//
//   j == 0 : sumi[iy] is overwritten, q8 quants are loaded with indices 0..3
//   j != 0 : the result is added to sumi[iy], q8 quants are loaded with indices 4..7
//   i      : forwarded unchanged to q8.load_quants(iy, i, index)
template <typename Q8, typename Bits>
static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    // The magnitudes depend only on bits, not on the q8 row: compute them once
    // instead of once per iy.
    __m256i magnitude[4];
    for (int k = 0; k < 4; ++k) {
        magnitude[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
    }

    const bool start_sum = j == 0;
    const int  first     = start_sum ? 0 : 4;
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        __m256i p[4];
        for (int k = 0; k < 4; ++k) {
            const __m256i signed_q8 = _mm256_sign_epi8(q8.load_quants(iy, i, first + k), bits.values[k]);
            p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(magnitude[k], signed_q8));
        }
        // 32-bit adds wrap, so the grouping of the partial sums does not affect the result.
        const __m256i total = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        // sumi[iy] is only read when we are continuing a sum.
        sumi[iy] = start_sum ? total : _mm256_add_epi32(sumi[iy], total);
    }
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1438,69 +1438,6 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
|
||||
}
|
||||
|
||||
// For every q8 row iy in [0, Q8::nrc_y), combines the four entries of bits.values with
// the matching q8 quants and 16-bit scales and stores/accumulates the 32-bit result in sumi[iy].
//
// j == 0 : sumi[iy] is overwritten; q8.load_quants is called with indices 0..3.
// j != 0 : the result is added to sumi[iy]; q8.load_quants is called with indices 4..7.
// i is passed through to q8.load_quants unchanged.
template <typename Q8, typename Bits>
inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    const bool overwrite = j == 0;
    const int  offset    = overwrite ? 0 : 4;
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        // _mm256_maddubs_epi16: unsigned bytes (bits.values) times signed bytes (q8), summed in pairs
        const auto prod16 = [&](int k) {
            return _mm256_maddubs_epi16(bits.values[k], q8.load_quants(iy, i, offset + k));
        };
#ifdef HAVE_FANCY_SIMD
        // Only read the previous sum when we are continuing one.
        __m256i running = overwrite ? _mm256_setzero_si256() : sumi[iy];
        running = _mm256_dpwssd_epi32(running, scales[0], prod16(0));
        running = _mm256_dpwssd_epi32(running, scales[1], prod16(1));
        running = _mm256_dpwssd_epi32(running, scales[2], prod16(2));
        running = _mm256_dpwssd_epi32(running, scales[3], prod16(3));
        sumi[iy] = running;
#else
        const __m256i t0 = _mm256_madd_epi16(scales[0], prod16(0));
        const __m256i t1 = _mm256_madd_epi16(scales[1], prod16(1));
        const __m256i t2 = _mm256_madd_epi16(scales[2], prod16(2));
        const __m256i t3 = _mm256_madd_epi16(scales[3], prod16(3));
        // Wrapping 32-bit adds: grouping of the partial sums does not change the result.
        const __m256i block_sum = _mm256_add_epi32(_mm256_add_epi32(t0, t2), _mm256_add_epi32(t1, t3));
        sumi[iy] = overwrite ? block_sum : _mm256_add_epi32(sumi[iy], block_sum);
#endif
    }
}
|
||||
|
||||
// Variant of multiply_add() for signed bytes in bits.values.
//
// _mm256_maddubs_epi16 interprets its first operand as unsigned, so the sign of
// bits.values[k] is transferred to the q8 quants with _mm256_sign_epi8 and the
// magnitude of bits.values[k] is used as the unsigned operand.
//
// j == 0 : sumi[iy] is overwritten; q8.load_quants is called with indices 0..3.
// j != 0 : the result is added to sumi[iy]; q8.load_quants is called with indices 4..7.
// i is passed through to q8.load_quants unchanged.
template <typename Q8, typename Bits>
inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    // |bits.values[k]| is independent of the q8 row, so compute it once up front
    // rather than inside the iy loop.
    __m256i abs_values[4];
    for (int k = 0; k < 4; ++k) {
        abs_values[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
    }

    const bool overwrite = j == 0;
    const int  offset    = overwrite ? 0 : 4;
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        __m256i p[4];
        for (int k = 0; k < 4; ++k) {
            const __m256i q8_signed = _mm256_sign_epi8(q8.load_quants(iy, i, offset + k), bits.values[k]);
            p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(abs_values[k], q8_signed));
        }
        // Wrapping 32-bit adds: grouping of the partial sums does not change the result.
        const __m256i block_sum = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        // Only read the previous sum when we are continuing one.
        sumi[iy] = overwrite ? block_sum : _mm256_add_epi32(sumi[iy], block_sum);
    }
}
|
||||
|
||||
struct SignHelper {
|
||||
inline __m256i make_signs(uint32_t sign_bits) const {
|
||||
auto aux256 = _mm256_set1_epi32(sign_bits);
|
||||
@@ -8209,10 +8146,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
|
||||
}
|
||||
#else
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerQ6K> ||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ2K>||
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2K>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ3K>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ4K>||
|
||||
std::is_same_v<Dequantizer, DequantizerIQ5K>||
|
||||
|
||||
Reference in New Issue
Block a user