mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Some performance tweaks
This commit is contained in:
@@ -3492,6 +3492,8 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
}
|
||||
}
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
static inline uint32_t trellis_next(uint32_t& val) {
|
||||
constexpr uint32_t ka = 89226354;
|
||||
constexpr uint32_t kb = 64248484;
|
||||
@@ -3533,22 +3535,81 @@ static inline float trellis_gen(uint32_t& val, uint32_t* s) {
|
||||
return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
|
||||
}
|
||||
|
||||
struct Trellis1 {
|
||||
constexpr static uint32_t kmask = 0x8fff8fff;
|
||||
constexpr static uint32_t km32 = 0x3b603b60;
|
||||
constexpr static uint32_t ka = 89226354;
|
||||
constexpr static uint32_t kb = 64248484;
|
||||
constexpr static uint32_t ka1 = ka*ka;
|
||||
constexpr static uint32_t kb1 = kb*ka+kb;
|
||||
constexpr static uint32_t ka2 = ka1*ka;
|
||||
constexpr static uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr static uint32_t ka3 = ka2*ka;
|
||||
constexpr static uint32_t kb3 = kb2*ka+kb;
|
||||
constexpr static uint32_t ka4 = ka3*ka;
|
||||
constexpr static uint32_t kb4 = kb3*ka+kb;
|
||||
constexpr static uint32_t ka5 = ka4*ka;
|
||||
constexpr static uint32_t kb5 = kb4*ka+kb;
|
||||
constexpr static uint32_t ka6 = ka5*ka;
|
||||
constexpr static uint32_t kb6 = kb5*ka+kb;
|
||||
constexpr static uint32_t ka7 = ka6*ka;
|
||||
constexpr static uint32_t kb7 = kb6*ka+kb;
|
||||
const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
|
||||
const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
|
||||
const __m256i mask1 = _mm256_set1_epi32(kmask);
|
||||
const __m256i mask2 = _mm256_set1_epi32(km32);
|
||||
|
||||
inline __m256i next8(uint32_t val) const {
|
||||
auto mval = _mm256_set1_epi32(val);
|
||||
auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
|
||||
return _mm256_and_si256(mres, mask1) ^ mask2;
|
||||
}
|
||||
};
|
||||
|
||||
//static inline __m256 trellis_gen8(uint32_t val) {
|
||||
// __m256i i8 = trellis_next8(val);
|
||||
// // split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
||||
// __m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
||||
// __m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
||||
// __m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
|
||||
// __m128i lo0123 = _mm256_extracti128_si256(lower_halves_lanes32, 0); // Extracts [00L0, 00L1, 00L2, 00L3]
|
||||
// __m128i lo4567 = _mm256_extracti128_si256(lower_halves_lanes32, 1); // Extracts [00L4, 00L5, 00L6, 00L7]
|
||||
// __m128i hlo = _mm_packus_epi32(lo0123, lo4567);
|
||||
// __m128i hi0123 = _mm256_extracti128_si256(upper_halves_lanes32, 0); // Extracts [00H0, 00H1, 00H2, 00H3]
|
||||
// __m128i hi4567 = _mm256_extracti128_si256(upper_halves_lanes32, 1); // Extracts [00H4, 00H5, 00H6, 00H7]
|
||||
// __m128i hhi = _mm_packus_epi32(hi0123, hi4567);
|
||||
// // widen both to 8xfloat32 and sum
|
||||
// __m256 f1 = _mm256_cvtph_ps(hlo);
|
||||
// __m256 f2 = _mm256_cvtph_ps(hhi);
|
||||
// return _mm256_add_ps(f1, f2);
|
||||
//}
|
||||
|
||||
static inline __m256 trellis_gen8(__m256i i8) {
|
||||
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
||||
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
||||
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
||||
__m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
|
||||
// 00L0, 00L1, 00L2, 00L3, 00H0, 00H1, 00H2, 00H3, 00L4, 00L5, 00L6, 00L7, 00H4, 00H5, 00H6, 00H7
|
||||
auto iv = _mm256_packus_epi32(lower_halves_lanes32, upper_halves_lanes32);
|
||||
// 00L0, 00L1, 00L2, 00L3, 00L4, 00L5, 00L6, 00L7, 00H0, 00H1, 00H2, 00H3, 00H4, 00H5, 00H6, 00H7
|
||||
iv = _mm256_permute4x64_epi64(iv, 0xd8);
|
||||
auto fv1 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 0));
|
||||
auto fv2 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 1));
|
||||
return _mm256_add_ps(fv1, fv2);
|
||||
}
|
||||
static inline __m256 trellis_gen8(uint32_t val) {
|
||||
__m256i i8 = trellis_next8(val);
|
||||
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
|
||||
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
|
||||
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
|
||||
__m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
|
||||
__m128i lo0123 = _mm256_extracti128_si256(lower_halves_lanes32, 0); // Extracts [00L0, 00L1, 00L2, 00L3]
|
||||
__m128i lo4567 = _mm256_extracti128_si256(lower_halves_lanes32, 1); // Extracts [00L4, 00L5, 00L6, 00L7]
|
||||
__m128i hlo = _mm_packus_epi32(lo0123, lo4567);
|
||||
__m128i hi0123 = _mm256_extracti128_si256(upper_halves_lanes32, 0); // Extracts [00H0, 00H1, 00H2, 00H3]
|
||||
__m128i hi4567 = _mm256_extracti128_si256(upper_halves_lanes32, 1); // Extracts [00H4, 00H5, 00H6, 00H7]
|
||||
__m128i hhi = _mm_packus_epi32(hi0123, hi4567);
|
||||
// widen both to 8xfloat32 and sum
|
||||
__m256 f1 = _mm256_cvtph_ps(hlo);
|
||||
__m256 f2 = _mm256_cvtph_ps(hhi);
|
||||
return _mm256_add_ps(f1, f2);
|
||||
// 00L0, 00L1, 00L2, 00L3, 00H0, 00H1, 00H2, 00H3, 00L4, 00L5, 00L6, 00L7, 00H4, 00H5, 00H6, 00H7
|
||||
auto iv = _mm256_packus_epi32(lower_halves_lanes32, upper_halves_lanes32);
|
||||
// 00L0, 00L1, 00L2, 00L3, 00L4, 00L5, 00L6, 00L7, 00H0, 00H1, 00H2, 00H3, 00H4, 00H5, 00H6, 00H7
|
||||
iv = _mm256_permute4x64_epi64(iv, 0xd8);
|
||||
auto fv1 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 0));
|
||||
auto fv2 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 1));
|
||||
return _mm256_add_ps(fv1, fv2);
|
||||
}
|
||||
|
||||
static inline __m256i trellis_next8(uint32_t val1, uint32_t val2) {
|
||||
@@ -3592,7 +3653,11 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
__m256 accd[nrc_y];
|
||||
Trellis1 trellis;
|
||||
|
||||
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
|
||||
|
||||
__m256 accd[k_acc];
|
||||
const float * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
|
||||
|
||||
@@ -3601,35 +3666,61 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
const float d = *dptr * 31.75f * 1.05f;
|
||||
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
for (int j = 0; j < 128; j+=8) {
|
||||
uint32_t val1 = ql[j/8] + 4096;
|
||||
uint32_t val2 = ql[j/8+16] + 4096;
|
||||
const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf];
|
||||
const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4];
|
||||
const __m256 x_val1 = trellis_gen8(val1);
|
||||
const __m256 x_val2 = trellis_gen8(val2);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
accd[iy]
|
||||
);
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
accd[iy]
|
||||
);
|
||||
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] & 0xf]);
|
||||
auto scale2 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] >> 4]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t val1 = ql[4*ib+j+ 0] + 4096;
|
||||
uint32_t val2 = ql[4*ib+j+16] + 4096;
|
||||
//const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(val1));
|
||||
//const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(val2));
|
||||
const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1)));
|
||||
const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2)));
|
||||
if constexpr (nrc_y == 1) {
|
||||
accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j ), x_val1, accd[0]);
|
||||
accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[1]);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j ), x_val1, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//for (int j = 0; j < 128; j+=8) {
|
||||
// uint32_t val1 = ql[j/8] + 4096;
|
||||
// uint32_t val2 = ql[j/8+16] + 4096;
|
||||
// const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf];
|
||||
// const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4];
|
||||
// const __m256 x_val1 = trellis_gen8(val1);
|
||||
// const __m256 x_val2 = trellis_gen8(val2);
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
// accd[iy]
|
||||
// );
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
// accd[iy]
|
||||
// );
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
|
||||
info.store(ix, iy, hsum_float_8(res));
|
||||
if constexpr (nrc_y == 1) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_add_ps(accd[0], accd[1]));
|
||||
info.store(ix, 0, hsum_float_8(res));
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
|
||||
info.store(ix, iy, hsum_float_8(res));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3787,8 +3878,6 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
if (nrc_x%4) {
|
||||
|
||||
Reference in New Issue
Block a user