mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 11:00:00 +00:00
Slighty faster iq2_kt
This commit is contained in:
@@ -3655,8 +3655,12 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
|
||||
auto shifts = _mm_set_epi32(0, 0, 4, 0);
|
||||
auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
|
||||
__m256 accd[k_acc];
|
||||
const float * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
|
||||
@@ -3670,47 +3674,28 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
|
||||
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
|
||||
s8 = _mm_shuffle_epi8(values, s8);
|
||||
auto s32 = _mm256_cvtepi8_epi32(s8);
|
||||
s_helper.vec = _mm256_cvtepi32_ps(s32);
|
||||
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] & 0xf]);
|
||||
auto scale2 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] >> 4]);
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t val1 = ql[4*ib+j+ 0] + 4096;
|
||||
uint32_t val2 = ql[4*ib+j+16] + 4096;
|
||||
//const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(val1));
|
||||
//const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(val2));
|
||||
const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1)));
|
||||
const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2)));
|
||||
auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096)));
|
||||
auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096)));
|
||||
if constexpr (nrc_y == 1) {
|
||||
accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j ), x_val1, accd[0]);
|
||||
accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[1]);
|
||||
accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[0]);
|
||||
accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[1]);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j ), x_val1, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//for (int j = 0; j < 128; j+=8) {
|
||||
// uint32_t val1 = ql[j/8] + 4096;
|
||||
// uint32_t val2 = ql[j/8+16] + 4096;
|
||||
// const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf];
|
||||
// const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4];
|
||||
// const __m256 x_val1 = trellis_gen8(val1);
|
||||
// const __m256 x_val2 = trellis_gen8(val2);
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
// accd[iy]
|
||||
// );
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
// accd[iy]
|
||||
// );
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
@@ -3727,7 +3712,7 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
static inline __m256 abs_ps(__m256 vals) {
|
||||
// Clear sign-bit of all the 32-bit floats in vals
|
||||
__m256 sign_bit = _mm256_set1_ps(-0.0f);
|
||||
__m256 sign_bit = _mm256_set1_ps(-0.0f);
|
||||
return _mm256_andnot_ps(sign_bit, vals);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user