This commit is contained in:
Iwan Kawrakow
2025-06-14 06:19:18 +03:00
parent de4e6c797f
commit 6d38e43f1d

View File

@@ -100,33 +100,24 @@ struct Trellis2 {
template <bool is_8 = false>
struct Trellis3 {
constexpr static uint32_t ka = 89226354;
constexpr static uint32_t kb = 64248484;
constexpr static uint32_t ka = 0xCBAC1FED;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
constexpr static uint32_t kb2 = kb1*ka+kb;
constexpr static uint32_t ka3 = ka2*ka;
constexpr static uint32_t kb3 = kb2*ka+kb;
constexpr static uint32_t ka4 = ka3*ka;
constexpr static uint32_t kb4 = kb3*ka+kb;
constexpr static uint32_t ka5 = ka4*ka;
constexpr static uint32_t kb5 = kb4*ka+kb;
constexpr static uint32_t ka6 = ka5*ka;
constexpr static uint32_t kb6 = kb5*ka+kb;
constexpr static uint32_t ka7 = ka6*ka;
constexpr static uint32_t kb7 = kb6*ka+kb;
const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3);
const __m256i mkb = is_8 ? _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3);
const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
inline __m256i next8(uint32_t val1, uint32_t val2) const {
__m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
return _mm256_mullo_epi32(mval, mka);
}
inline __m256i next8(uint32_t val) const {
__m256i mval = _mm256_set1_epi32(val);
return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
return _mm256_mullo_epi32(mval, mka);
}
inline __m256 gen8(uint32_t val1, uint32_t val2) const {
auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f));
@@ -189,11 +180,11 @@ struct Trellis3 {
template <bool is_unsigned = false>
inline void next64(const uint32_t * val, __m256i * result) const {
const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126);
auto vka3 = _mm256_set1_epi32(ka3), vkb3 = _mm256_set1_epi32(kb3);
auto vka3 = _mm256_set1_epi32(ka3);
__m256i aux[8];
for (int i = 0; i < 4; ++i) {
auto i8_1 = next8(val[2*i+0], val[2*i+1]);
auto i8_2 = _mm256_add_epi32(_mm256_mullo_epi32(i8_1, vka3), vkb3);
auto i8_2 = _mm256_mullo_epi32(i8_1, vka3);
i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f));
i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f));
#ifdef HAVE_FANCY_SIMD
@@ -1419,22 +1410,17 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
}
struct Trellis3 {
constexpr static uint32_t ka = 89226354;
constexpr static uint32_t kb = 64248484;
constexpr static uint32_t ka = ;0xCBAC1FED;
constexpr static uint32_t ka1 = ka*ka;
constexpr static uint32_t kb1 = kb*ka+kb;
constexpr static uint32_t ka2 = ka1*ka;
constexpr static uint32_t kb2 = kb1*ka+kb;
constexpr static uint32_t ka3 = ka2*ka;
constexpr static uint32_t kb3 = kb2*ka+kb;
const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3};
const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3};
const uint8x16_t shuffle = load_shuffle();
inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const {
uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)};
result.val[0] = vmlaq_u32(mkb, mka, result.val[0]);
result.val[1] = vmlaq_u32(mkb, mka, result.val[1]);
result.val[0] = vmulq_u32(mka, result.val[0]);
result.val[1] = vmulq_u32(mka, result.val[1]);
return result;
}
inline int8x16x2_t next32(const uint32_t * val) const {
@@ -1457,12 +1443,12 @@ struct Trellis3 {
int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
int8x16x2_t i8;
for (int i = 0; i < 2; ++i) {
i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+0]+v0));
i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+0]+v0));
i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+1]+v0));
i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+1]+v0));
i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));