diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3510cbaf..fffb3ab2 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1510,18 +1510,13 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da deq.new_row(ix); - for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { __m256i sumi[nrc_y]; - //deq.new_block(i, q8, sumi); deq.new_block(i); deq.prepare(i, 0); for (int iy = 0; iy < nrc_y; ++iy) { - //sumi[iy] = _mm256_sub_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), - // _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))), sumi[iy]); sumi[iy] = _mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[0], q8.load_quants(iy, i, 0)), _mm256_maddubs_epi16(deq.bits.values[1], q8.load_quants(iy, i, 1))); sumi[iy] = _mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(deq.bits.values[2], q8.load_quants(iy, i, 2)), @@ -1535,8 +1530,14 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da _mm256_maddubs_epi16(deq.bits.values[3], q8.load_quants(iy, i, 7))), sumi[iy]); sumi[iy] = _mm256_sub_epi16(sumi[iy], q8.load_bsums(iy, i)); } - for (int iy = 0; iy < nrc_y; ++iy) { - accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + if (i > 0) { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy])), accd[iy]); + } + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_mul_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(_mm256_madd_epi16(m1, sumi[iy]))); + } } } @@ -2040,7 +2041,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)), _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4)))); - accd[iy] = _mm256_add_epi32(dot, accd[iy]); + accd[iy] = i > 0 ? _mm256_add_epi32(dot, accd[iy]) : dot; #endif } } @@ -3275,7 +3276,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[2] = mul_mat_iq2tn_q8_K<3>; mm.funcs[3] = mul_mat_iq2tn_q8_K<4>; mm.funcs[4] = mul_mat_iq2tn_q8_K<5>; - mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; + //mm.funcs[5] = mul_mat_iq2tn_q8_K<6>; //mm.funcs[6] = mul_mat_iq2tn_q8_K<7>; //mm.funcs[7] = mul_mat_iq2tn_q8_K<8>; #endif