Some performance tweaks

This commit is contained in:
Iwan Kawrakow
2025-05-21 13:39:45 +03:00
parent 60a948bf44
commit 38692aaab4

View File

@@ -3492,6 +3492,8 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
}
}
#endif // Zen4 or vanilla AVX2
static inline uint32_t trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
@@ -3533,22 +3535,81 @@ static inline float trellis_gen(uint32_t& val, uint32_t* s) {
return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]);
}
struct Trellis1 {
    // Masks applied to each generated 32-bit value: kmask keeps the sign and
    // low bits of each 16-bit half, km32 then xors in a fixed bit pattern so
    // every half is a bounded fp16 value.
    constexpr static uint32_t kmask = 0x8fff8fff;
    constexpr static uint32_t km32 = 0x3b603b60;
    // Base LCG step: val' = ka*val + kb (mod 2^32).
    constexpr static uint32_t ka = 89226354;
    constexpr static uint32_t kb = 64248484;
    // (kaN, kbN) are the coefficients of the (N+1)-fold composition of the
    // LCG, i.e. applying the step N+1 times equals val*kaN + kbN. This lets
    // next8() produce eight consecutive states from one seed in parallel.
    constexpr static uint32_t ka1 = ka*ka;
    constexpr static uint32_t kb1 = kb*ka+kb;
    constexpr static uint32_t ka2 = ka1*ka;
    constexpr static uint32_t kb2 = kb1*ka+kb;
    constexpr static uint32_t ka3 = ka2*ka;
    constexpr static uint32_t kb3 = kb2*ka+kb;
    constexpr static uint32_t ka4 = ka3*ka;
    constexpr static uint32_t kb4 = kb3*ka+kb;
    constexpr static uint32_t ka5 = ka4*ka;
    constexpr static uint32_t kb5 = kb4*ka+kb;
    constexpr static uint32_t ka6 = ka5*ka;
    constexpr static uint32_t kb6 = kb5*ka+kb;
    constexpr static uint32_t ka7 = ka6*ka;
    constexpr static uint32_t kb7 = kb6*ka+kb;
    const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7);
    const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7);
    const __m256i mask1 = _mm256_set1_epi32(kmask);
    const __m256i mask2 = _mm256_set1_epi32(km32);
    // Returns the next 8 trellis codes for seed `val`:
    // lane i = km32 ^ (kmask & (ka_i*val + kb_i)).
    inline __m256i next8(uint32_t val) const {
        auto mval = _mm256_set1_epi32(val);
        auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
        // Use the explicit xor intrinsic: operator^ on __m256i is a
        // GCC/clang vector extension and does not compile with MSVC.
        return _mm256_xor_si256(_mm256_and_si256(mres, mask1), mask2);
    }
};
//static inline __m256 trellis_gen8(uint32_t val) {
// __m256i i8 = trellis_next8(val);
// // split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
// __m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
// __m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
// __m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
// __m128i lo0123 = _mm256_extracti128_si256(lower_halves_lanes32, 0); // Extracts [00L0, 00L1, 00L2, 00L3]
// __m128i lo4567 = _mm256_extracti128_si256(lower_halves_lanes32, 1); // Extracts [00L4, 00L5, 00L6, 00L7]
// __m128i hlo = _mm_packus_epi32(lo0123, lo4567);
// __m128i hi0123 = _mm256_extracti128_si256(upper_halves_lanes32, 0); // Extracts [00H0, 00H1, 00H2, 00H3]
// __m128i hi4567 = _mm256_extracti128_si256(upper_halves_lanes32, 1); // Extracts [00H4, 00H5, 00H6, 00H7]
// __m128i hhi = _mm_packus_epi32(hi0123, hi4567);
// // widen both to 8xfloat32 and sum
// __m256 f1 = _mm256_cvtph_ps(hlo);
// __m256 f2 = _mm256_cvtph_ps(hhi);
// return _mm256_add_ps(f1, f2);
//}
// Convert 8 trellis codes to floats. Each 32-bit lane of `i8` packs two fp16
// bit patterns; the result lane is the sum of those two halves as float32.
static inline __m256 trellis_gen8(__m256i i8) {
    const __m256i lo_mask = _mm256_set1_epi32(0x0000FFFF);
    const __m256i lo16 = _mm256_and_si256(i8, lo_mask);   // low fp16 of each lane
    const __m256i hi16 = _mm256_srli_epi32(i8, 16);       // high fp16 of each lane
    // packus interleaves per 128-bit half:
    //   L0 L1 L2 L3 H0 H1 H2 H3 | L4 L5 L6 L7 H4 H5 H6 H7
    __m256i packed = _mm256_packus_epi32(lo16, hi16);
    // reorder 64-bit groups (0,2,1,3) to get
    //   L0 L1 L2 L3 L4 L5 L6 L7 | H0 H1 H2 H3 H4 H5 H6 H7
    packed = _mm256_permute4x64_epi64(packed, 0xd8);
    const __m256 lo_f = _mm256_cvtph_ps(_mm256_castsi256_si128(packed));
    const __m256 hi_f = _mm256_cvtph_ps(_mm256_extracti128_si256(packed, 1));
    return _mm256_add_ps(lo_f, hi_f);
}
static inline __m256 trellis_gen8(uint32_t val) {
__m256i i8 = trellis_next8(val);
// split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi`
__m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF);
__m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask);
__m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16);
__m128i lo0123 = _mm256_extracti128_si256(lower_halves_lanes32, 0); // Extracts [00L0, 00L1, 00L2, 00L3]
__m128i lo4567 = _mm256_extracti128_si256(lower_halves_lanes32, 1); // Extracts [00L4, 00L5, 00L6, 00L7]
__m128i hlo = _mm_packus_epi32(lo0123, lo4567);
__m128i hi0123 = _mm256_extracti128_si256(upper_halves_lanes32, 0); // Extracts [00H0, 00H1, 00H2, 00H3]
__m128i hi4567 = _mm256_extracti128_si256(upper_halves_lanes32, 1); // Extracts [00H4, 00H5, 00H6, 00H7]
__m128i hhi = _mm_packus_epi32(hi0123, hi4567);
// widen both to 8xfloat32 and sum
__m256 f1 = _mm256_cvtph_ps(hlo);
__m256 f2 = _mm256_cvtph_ps(hhi);
return _mm256_add_ps(f1, f2);
// 00L0, 00L1, 00L2, 00L3, 00H0, 00H1, 00H2, 00H3, 00L4, 00L5, 00L6, 00L7, 00H4, 00H5, 00H6, 00H7
auto iv = _mm256_packus_epi32(lower_halves_lanes32, upper_halves_lanes32);
// 00L0, 00L1, 00L2, 00L3, 00L4, 00L5, 00L6, 00L7, 00H0, 00H1, 00H2, 00H3, 00H4, 00H5, 00H6, 00H7
iv = _mm256_permute4x64_epi64(iv, 0xd8);
auto fv1 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 0));
auto fv2 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 1));
return _mm256_add_ps(fv1, fv2);
}
static inline __m256i trellis_next8(uint32_t val1, uint32_t val2) {
@@ -3592,7 +3653,11 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
assert(n%QK_K == 0);
const int nb = n/QK_K;
__m256 accd[nrc_y];
Trellis1 trellis;
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
__m256 accd[k_acc];
const float * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
@@ -3601,35 +3666,61 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
const float d = *dptr * 31.75f * 1.05f;
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const uint16_t * ql = (const uint16_t *)x[i].ql;
for (int j = 0; j < 128; j+=8) {
uint32_t val1 = ql[j/8] + 4096;
uint32_t val2 = ql[j/8+16] + 4096;
const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf];
const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4];
const __m256 x_val1 = trellis_gen8(val1);
const __m256 x_val2 = trellis_gen8(val2);
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(
_mm256_load_ps(y[iy] + i*QK_K+j),
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
accd[iy]
);
accd[iy] = _mm256_fmadd_ps(
_mm256_load_ps(y[iy] + i*QK_K+j+128),
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
accd[iy]
);
for (int ib = 0; ib < QK_K/64; ++ib) {
auto scale1 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] & 0xf]);
auto scale2 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] >> 4]);
for (int j = 0; j < 4; ++j) {
uint32_t val1 = ql[4*ib+j+ 0] + 4096;
uint32_t val2 = ql[4*ib+j+16] + 4096;
//const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(val1));
//const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(val2));
const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1)));
const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2)));
if constexpr (nrc_y == 1) {
accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j ), x_val1, accd[0]);
accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[1]);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j ), x_val1, accd[iy]);
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[iy]);
}
}
}
}
//for (int j = 0; j < 128; j+=8) {
// uint32_t val1 = ql[j/8] + 4096;
// uint32_t val2 = ql[j/8+16] + 4096;
// const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf];
// const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4];
// const __m256 x_val1 = trellis_gen8(val1);
// const __m256 x_val2 = trellis_gen8(val2);
// for (int iy = 0; iy < nrc_y; ++iy) {
// accd[iy] = _mm256_fmadd_ps(
// _mm256_load_ps(y[iy] + i*QK_K+j),
// _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
// accd[iy]
// );
// accd[iy] = _mm256_fmadd_ps(
// _mm256_load_ps(y[iy] + i*QK_K+j+128),
// _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
// accd[iy]
// );
// }
//}
}
for (int iy = 0; iy < nrc_y; ++iy) {
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
info.store(ix, iy, hsum_float_8(res));
if constexpr (nrc_y == 1) {
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_add_ps(accd[0], accd[1]));
info.store(ix, 0, hsum_float_8(res));
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
info.store(ix, iy, hsum_float_8(res));
}
}
}
}
@@ -3787,8 +3878,6 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
}
}
#endif // Zen4 or vanilla AVX2
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
if (nrc_x%4) {