Slightly faster iq4_kt

PP is now almost 50% better than original, TG is ~20% better
This commit is contained in:
Iwan Kawrakow
2025-05-23 20:02:42 +03:00
parent 48c0b7d8f9
commit fb254f0c97

View File

@@ -269,40 +269,63 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
Trellis2 trellis;
union { __m256 vec; float val[8]; } s_helper;
union { __m256 vec; float val[8]; } s_helper;
union { __m256i vec; uint32_t val[8]; } o_helper; //, q_helper1, q_helper2;
__m256 accd[nrc_y];
const float * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
float row_sum[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const float *)info.src1_row(iy);
auto sum = _mm256_setzero_ps();
for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
row_sum[iy] = hsum_float_8(sum);
}
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
auto row_av = _mm256_set1_ps(dptr[1]);
auto dav = dptr[1];
//auto row_av = _mm256_set1_ps(dptr[1]);
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
const uint32_t * shb = x[i].qs;
const uint8_t * ql = (const uint8_t *)(shb + 8);
//const uint8_t * ql = (const uint8_t *)(shb + 8);
const uint8_t * ql = (const uint8_t *)(x[i].qs + 8);
const uint8_t * qh = ql + kNumGroups;
auto iscales = _mm256_loadu_si256((const __m256i *)shb);
iscales = _mm256_srli_epi32(_mm256_and_si256(iscales, _mm256_set1_epi32(0xff)), 1);
//auto iscales = _mm256_loadu_si256((const __m256i *)shb);
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
//vshb = _mm256_srli_epi32(vshb, 8);
for (int ib = 0; ib < 4; ++ib) {
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
//auto vq1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 0)));
//auto vq2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 32)));
//auto vqh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*ib)));
//vq1 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 8), _mm256_set1_epi32(0xf00)));
//vq2 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 4), _mm256_set1_epi32(0xf00)));
//q_helper1.vec = _mm256_add_epi32(vq1, _mm256_set1_epi32(o_helper.val[ib+0]));
//q_helper2.vec = _mm256_add_epi32(vq2, _mm256_set1_epi32(o_helper.val[ib+4]));
//const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
//const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
for (int j = 0; j < 4; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
// j/4 -> (32*ib+8*j)/4 = 8*ib + 2*j
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
//uint32_t val1 = q_helper1.val[2*j+0] + ((sh1 & 7) << 12);
//uint32_t val2 = q_helper2.val[2*j+0] + ((sh2 & 7) << 12);
//uint32_t val3 = q_helper1.val[2*j+1] + ((sh1 & 56) << 9);
//uint32_t val4 = q_helper2.val[2*j+1] + ((sh2 & 56) << 9);
auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -310,7 +333,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]);
//accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]);
}
}
}
@@ -351,7 +374,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]);
}
}
}