diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 8b8cae14..fc954b54 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -100,33 +100,24 @@ struct Trellis2 { template struct Trellis3 { - constexpr static uint32_t ka = 89226354; - constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka = 0xCBAC1FED; constexpr static uint32_t ka1 = ka*ka; - constexpr static uint32_t kb1 = kb*ka+kb; constexpr static uint32_t ka2 = ka1*ka; - constexpr static uint32_t kb2 = kb1*ka+kb; constexpr static uint32_t ka3 = ka2*ka; - constexpr static uint32_t kb3 = kb2*ka+kb; constexpr static uint32_t ka4 = ka3*ka; - constexpr static uint32_t kb4 = kb3*ka+kb; constexpr static uint32_t ka5 = ka4*ka; - constexpr static uint32_t kb5 = kb4*ka+kb; constexpr static uint32_t ka6 = ka5*ka; - constexpr static uint32_t kb6 = kb5*ka+kb; constexpr static uint32_t ka7 = ka6*ka; - constexpr static uint32_t kb7 = kb6*ka+kb; const __m256i mka = is_8 ? _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7) : _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); - const __m256i mkb = is_8 ? _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7) : _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); const __m256i shuffle = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); inline __m256i next8(uint32_t val1, uint32_t val2) const { __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); - return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_mullo_epi32(mval, mka); } inline __m256i next8(uint32_t val) const { __m256i mval = _mm256_set1_epi32(val); - return _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_mullo_epi32(mval, mka); } inline __m256 gen8(uint32_t val1, uint32_t val2) const { auto v8 = _mm256_and_si256(next8(val1, val2), _mm256_set1_epi32(0x3f3f3f3f)); @@ -189,11 +180,11 @@ struct Trellis3 { template inline void next64(const uint32_t * val, __m256i * result) const { const __m256i offset = is_unsigned ? _mm256_setzero_si256() : _mm256_set1_epi32(-126); - auto vka3 = _mm256_set1_epi32(ka3), vkb3 = _mm256_set1_epi32(kb3); + auto vka3 = _mm256_set1_epi32(ka3); __m256i aux[8]; for (int i = 0; i < 4; ++i) { auto i8_1 = next8(val[2*i+0], val[2*i+1]); - auto i8_2 = _mm256_add_epi32(_mm256_mullo_epi32(i8_1, vka3), vkb3); + auto i8_2 = _mm256_mullo_epi32(i8_1, vka3); i8_1 = _mm256_and_si256(i8_1, _mm256_set1_epi32(0x3f3f3f3f)); i8_2 = _mm256_and_si256(i8_2, _mm256_set1_epi32(0x3f3f3f3f)); #ifdef HAVE_FANCY_SIMD @@ -1419,22 +1410,17 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf } struct Trellis3 { - constexpr static uint32_t ka = 89226354; - constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka = ;0xCBAC1FED; constexpr static uint32_t ka1 = ka*ka; - constexpr static uint32_t kb1 = kb*ka+kb; constexpr static uint32_t ka2 = ka1*ka; - constexpr static uint32_t kb2 = kb1*ka+kb; constexpr static uint32_t ka3 = ka2*ka; - constexpr static uint32_t kb3 = kb2*ka+kb; const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3}; - const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3}; const uint8x16_t shuffle = load_shuffle(); inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const { uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)}; - result.val[0] = vmlaq_u32(mkb, mka, result.val[0]); - result.val[1] = vmlaq_u32(mkb, mka, result.val[1]); + result.val[0] = vmulq_u32(mka, result.val[0]); + result.val[1] = vmulq_u32(mka, result.val[1]); return result; } inline int8x16x2_t next32(const uint32_t * val) const { @@ -1457,12 +1443,12 @@ struct Trellis3 { int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)}; int8x16x2_t i8; for (int i = 0; i < 2; ++i) { - i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+0]+v0)); + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+0]+v0)); i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f)); auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1])); - i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+1]+v0)); + i8.val[0] = vmulq_u32(mka, vdupq_n_u32(val[2*i+1]+v0)); i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]); i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f)); i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));