mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Slightly faster iq4_kt
This commit is contained in:
@@ -115,7 +115,8 @@ struct Trellis2 {
|
||||
const __m256i mask2 = _mm256_set1_epi32(km32);
|
||||
|
||||
inline __m256i next8(uint32_t val1, uint32_t val2) {
|
||||
__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
|
||||
__m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
|
||||
//__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
|
||||
__m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
|
||||
return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32);
|
||||
}
|
||||
@@ -251,6 +252,15 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
}
|
||||
|
||||
// QuantizerIQKT<block_size = 32, group_size = 4, num_bits = 15>;
|
||||
// constexpr static int kSuperBlockSize = QK_K;
|
||||
// constexpr static int kBlockSize = block_size; -> 32
|
||||
// constexpr static int kGroupSize = group_size; -> 4
|
||||
// constexpr static int kNg = kBlockSize/kGroupSize; -> 8
|
||||
// constexpr static int kNblock = kSuperBlockSize/kBlockSize; -> 8
|
||||
// constexpr static int kNumVal = 1 << num_bits; -> 32768
|
||||
// constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize -> 64
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -259,66 +269,89 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
Trellis2 trellis;
|
||||
|
||||
union { __m256 vec; float val[8]; } s_helper;
|
||||
|
||||
__m256 accd[nrc_y];
|
||||
__m256 accd2[nrc_y];
|
||||
const float * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
const float d = dptr[0] * 31.75f * 1.01f;
|
||||
const float row_av = dptr[1];
|
||||
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
|
||||
auto row_av = _mm256_set1_ps(dptr[1]);
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_setzero_ps();
|
||||
accd2[iy] = _mm256_setzero_ps();
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
for (int j = 0; j < 128; j+=8) {
|
||||
const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
|
||||
const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
|
||||
const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
|
||||
const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
|
||||
const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
|
||||
uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
|
||||
const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
accd[iy]
|
||||
);
|
||||
accd[iy] = _mm256_fmadd_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
accd[iy]
|
||||
);
|
||||
accd2[iy] = _mm256_add_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
accd2[iy]
|
||||
);
|
||||
accd2[iy] = _mm256_add_ps(
|
||||
_mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
accd2[iy]
|
||||
);
|
||||
auto iscales = _mm256_loadu_si256((const __m256i *)shb);
|
||||
iscales = _mm256_srli_epi32(_mm256_and_si256(iscales, _mm256_set1_epi32(0xff)), 1);
|
||||
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
|
||||
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
|
||||
const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
// j/4 -> (32*ib+8*j)/4 = 8*ib + 2*j
|
||||
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
|
||||
auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+ 0);
|
||||
auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
|
||||
accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
|
||||
accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
//for (int j = 0; j < 128; j+=8) {
|
||||
// const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
|
||||
// const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
|
||||
// const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
|
||||
// const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
|
||||
// const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
|
||||
// const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
|
||||
// uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
|
||||
// uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
|
||||
// uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
|
||||
// uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
|
||||
// const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
|
||||
// const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
|
||||
// for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
|
||||
// accd[iy]
|
||||
// );
|
||||
// accd[iy] = _mm256_fmadd_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
// _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
|
||||
// accd[iy]
|
||||
// );
|
||||
// accd2[iy] = _mm256_add_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j),
|
||||
// accd2[iy]
|
||||
// );
|
||||
// accd2[iy] = _mm256_add_ps(
|
||||
// _mm256_load_ps(y[iy] + i*QK_K+j+128),
|
||||
// accd2[iy]
|
||||
// );
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
|
||||
__m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
|
||||
info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
|
||||
info.store(ix, iy, hsum_float_8(accd[iy]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user