mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 23:40:10 +00:00
iqk_mul_mat: small improvement for iq3_s
The same as in llamafile. We get
PP-512 = 96.6 t/s
TG-128 = 7.77 t/s @ 4 threads
TG-128 = 14.4 t/s @ 8 threads
TG-128 = 16.3 t/s @ 16 threads
This commit is contained in:
@@ -1145,6 +1145,21 @@ struct SignHelper {
|
||||
//aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
|
||||
//return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
|
||||
}
|
||||
inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
|
||||
auto s128 = _mm_loadu_si128((const __m128i *)sign_bits);
|
||||
auto s256 = MM256_SET_M128I(s128, s128);
|
||||
__m256i aux256;
|
||||
auto shuffle = mask1;
|
||||
auto step = _mm256_set1_epi8(4);
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
|
||||
values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
|
||||
values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
|
||||
values[2] = _mm256_sign_epi8(values[2], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
|
||||
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
|
||||
values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
|
||||
}
|
||||
const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
|
||||
const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull);
|
||||
const __m256i mone = _mm256_set1_epi8(1);
|
||||
@@ -1181,65 +1196,37 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
|
||||
uint32_t val[8];
|
||||
};
|
||||
|
||||
struct SignSelf {
|
||||
SignSelf(const SignHelper& sh, const __m256i& min_value, __m256i * values, const uint16_t * sidx) :
|
||||
sh(sh), min_value(min_value), values(values), sidx(sidx) {}
|
||||
inline void apply(int k) {
|
||||
values[k] = _mm256_add_epi8(_mm256_sign_epi8(values[k], sh.make_signs(sidx+2*k)), min_value);
|
||||
}
|
||||
const SignHelper& sh;
|
||||
const __m256i& min_value;
|
||||
__m256i * values;
|
||||
const uint16_t * sidx;
|
||||
};
|
||||
template <typename Q8>
|
||||
struct SignQ8 {
|
||||
SignQ8(const Q8& q8, const SignHelper& sh, __m256i * values, const uint16_t * sidx, int i, int j) :
|
||||
q8(q8), sh(sh), values(values), sidx(sidx), i(i), j(j) {}
|
||||
inline void apply(int k) {
|
||||
values[k] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+k), sh.make_signs(sidx+2*k));
|
||||
}
|
||||
const Q8& q8;
|
||||
const SignHelper& sh;
|
||||
__m256i * values;
|
||||
const uint16_t * sidx;
|
||||
int i;
|
||||
int j;
|
||||
};
|
||||
|
||||
template <typename ApplySignes>
|
||||
inline static void make1(int k, const __m128i& idx_l, uint8_t qh, __m256i * values, const __m256i& idx_shift, const __m256i& idx_mask,
|
||||
ApplySignes& as) {
|
||||
inline static void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values, const __m256i& idx_shift, const __m256i& idx_mask) {
|
||||
index_t idx;
|
||||
idx.vec = _mm256_set1_epi32(qh);
|
||||
idx.vec = _mm256_and_si256(_mm256_sllv_epi32(idx.vec, idx_shift), idx_mask);
|
||||
idx.vec = _mm256_or_si256(idx.vec, _mm256_cvtepi16_epi32(idx_l));
|
||||
values[k] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
|
||||
auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
|
||||
auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
|
||||
idx.vec = _mm256_or_si256(idx_h, idx_l);
|
||||
values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
|
||||
iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
|
||||
idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs + 8)));
|
||||
idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
|
||||
idx.vec = _mm256_or_si256(idx_h, idx_l);
|
||||
values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
|
||||
iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
|
||||
as.apply(k);
|
||||
}
|
||||
template <typename ApplySignes>
|
||||
inline static void make2(int k, const uint8_t * qs, const uint8_t * qh,
|
||||
__m256i * values, const __m256i& idx_shift, const __m256i& idx_mask, ApplySignes& as) {
|
||||
auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
|
||||
make1(k+0, _mm256_castsi256_si128 (idx_l ), qh[0], values, idx_shift, idx_mask, as);
|
||||
make1(k+1, _mm256_extractf128_si256(idx_l, 1), qh[1], values, idx_shift, idx_mask, as);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
auto qs = x[i].qs + 32*j;
|
||||
auto qh = x[i].qh + 4*j;
|
||||
SignSelf ss(sh, min_value, bits.values, (const uint16_t *)x[i].signs + 8*j);
|
||||
make2(0, qs+ 0, qh+0, bits.values, idx_shift, idx_mask, ss);
|
||||
make2(2, qs+16, qh+2, bits.values, idx_shift, idx_mask, ss);
|
||||
prepare_unsigned(i, j);
|
||||
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values);
|
||||
for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value);
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void prepare(int i, int j, const Q8& q8, __m256i * q8_quants) {
|
||||
prepare_unsigned(i, j);
|
||||
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
|
||||
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
|
||||
}
|
||||
|
||||
inline void prepare_unsigned(int i, int j) {
|
||||
auto qs = x[i].qs + 32*j;
|
||||
auto qh = x[i].qh + 4*j;
|
||||
SignQ8 sq8(q8, sh, q8_quants, (const uint16_t *)x[i].signs + 8*j, i, j);
|
||||
make2(0, qs+ 0, qh+0, bits.values, idx_shift, idx_mask, sq8);
|
||||
make2(2, qs+16, qh+2, bits.values, idx_shift, idx_mask, sq8);
|
||||
make2(qs+ 0, qh+0, bits.values+0, idx_shift, idx_mask);
|
||||
make2(qs+16, qh+2, bits.values+2, idx_shift, idx_mask);
|
||||
}
|
||||
|
||||
constexpr static int minv = 16;
|
||||
|
||||
Reference in New Issue
Block a user