mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Playing with horizontal sums
This commit is contained in:
@@ -260,6 +260,19 @@ inline float hmax_float_8(__m256 x) {
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
|
||||
return _mm_cvtss_f32(max4);
|
||||
}
|
||||
IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
|
||||
_mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
|
||||
}
|
||||
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
|
||||
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
|
||||
}
|
||||
IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
|
||||
union { __m256 vec; float val[8]; } h;
|
||||
h.vec = hsum_float_8x8(accm);
|
||||
for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, h.val[iy]);
|
||||
}
|
||||
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
|
||||
@@ -1128,9 +1141,17 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
if constexpr (nrc_y == 8) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accm[iy] = _mm256_add_ps(accm[iy], _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)));
|
||||
}
|
||||
store_8(ix, accm, info);
|
||||
}
|
||||
else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1230,9 +1251,18 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
if constexpr (nrc_y == 8) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
accm[iy] = _mm256_add_ps(accm[iy], sum256);
|
||||
}
|
||||
store_8(ix, accm, info);
|
||||
}
|
||||
else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
|
||||
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1833,8 +1863,12 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
if constexpr (nrc_y == 8) {
|
||||
store_8(ix, accd, info);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1877,10 +1911,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
if constexpr (nrc_y == 8) {
|
||||
store_8(ix, accd, info);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1926,8 +1963,12 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
if constexpr (nrc_y == 8) {
|
||||
store_8(ix, accd, info);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2094,8 +2135,12 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
if constexpr (nrc_y == 8) {
|
||||
store_8(ix, accd, info);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2999,10 +3044,17 @@ struct ScaleHelperQ_1 {
|
||||
}
|
||||
};
|
||||
|
||||
struct MinusType0 {
|
||||
template <int nrc_y> struct MinusType0 {
|
||||
inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
|
||||
inline float compute(float d, int) const { return d; }
|
||||
inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
|
||||
//inline void store(int ix, __m256 * acc, const DataInfo& info) {
|
||||
// if constexpr (nrc_y == 8) {
|
||||
// store_8(ix, acc, info);
|
||||
// } else {
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_8(acc[iy]));
|
||||
// }
|
||||
//}
|
||||
};
|
||||
|
||||
template <int nrc_y> struct MinusType1 {
|
||||
@@ -3022,6 +3074,23 @@ template <int nrc_y> struct MinusType1 {
|
||||
const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
|
||||
return hsum_float_4(_mm_add_ps(sum, accm[iy]));
|
||||
}
|
||||
//inline void store(int ix, const __m256 * acc, const DataInfo& info) {
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// accm[iy] = _mm_add_ps(accm[iy], _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)));
|
||||
// }
|
||||
// if constexpr (nrc_y >= 4) {
|
||||
// union { __m128 vec; float val[4]; } h;
|
||||
// for (int i = 0; i < nrc_y/4; ++i) {
|
||||
// accm[4*i+0] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+2]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+2]));
|
||||
// accm[4*i+1] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+1], accm[4*i+3]), _mm_unpackhi_ps(accm[4*i+1], accm[4*i+3]));
|
||||
// h.vec = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+1]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+1]));
|
||||
// for (int j = 0; j < 4; ++j) info.store(ix, 4*i+j, h.val[j]);
|
||||
// }
|
||||
// for (int iy = 4*(nrc_y/4); iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
|
||||
// } else {
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
|
||||
// }
|
||||
//}
|
||||
};
|
||||
|
||||
template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
|
||||
@@ -3054,6 +3123,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
|
||||
}
|
||||
}
|
||||
}
|
||||
//accm.store(ix, acc, info);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, accm.result(acc[iy], iy));
|
||||
}
|
||||
@@ -3061,7 +3131,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
|
||||
};
|
||||
|
||||
template <int nrc_y, bool is_multiple_of_4>
|
||||
using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
|
||||
using AccumType0 = AccumT<MinusType0<nrc_y>, nrc_y, is_multiple_of_4>;
|
||||
|
||||
template <int nrc_y, bool is_multiple_of_4>
|
||||
using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
|
||||
|
||||
Reference in New Issue
Block a user