diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 9f0e3763..7cc524c2 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -286,7 +286,6 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
         const float * dptr = (const float *)((const char*)vx + ix*bx);
         auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
         auto dav = dptr[1];
-        //auto row_av = _mm256_set1_ps(dptr[1]);
         const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
         for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();

@@ -294,38 +293,21 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
         for (int i = 0; i < nb; ++i) {
             auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
             const uint32_t * shb = x[i].qs;
-            //const uint8_t * ql = (const uint8_t *)(shb + 8);
-            const uint8_t * ql = (const uint8_t *)(x[i].qs + 8);
+            const uint8_t * ql = (const uint8_t *)(shb + 8);
             const uint8_t * qh = ql + kNumGroups;
-            //auto iscales = _mm256_loadu_si256((const __m256i *)shb);
             auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
             s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
             o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
-            //vshb = _mm256_srli_epi32(vshb, 8);
             for (int ib = 0; ib < 4; ++ib) {
                 auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
                 auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
-                //auto vq1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 0)));
-                //auto vq2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 32)));
-                //auto vqh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*ib)));
-                //vq1 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 8), _mm256_set1_epi32(0xf00)));
-                //vq2 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 4), _mm256_set1_epi32(0xf00)));
-                //q_helper1.vec = _mm256_add_epi32(vq1, _mm256_set1_epi32(o_helper.val[ib+0]));
-                //q_helper2.vec = _mm256_add_epi32(vq2, _mm256_set1_epi32(o_helper.val[ib+4]));
-                //const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15);
-                //const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15);
                 for (int j = 0; j < 4; ++j) {
                     const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
                     const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
-                    // j/4 -> (32*ib+8*j)/4 = 8*ib + 2*j
                     uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
                     uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
                     uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
                     uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
-                    //uint32_t val1 = q_helper1.val[2*j+0] + ((sh1 & 7) << 12);
-                    //uint32_t val2 = q_helper2.val[2*j+0] + ((sh2 & 7) << 12);
-                    //uint32_t val3 = q_helper1.val[2*j+1] + ((sh1 & 56) << 9);
-                    //uint32_t val4 = q_helper2.val[2*j+1] + ((sh2 & 56) << 9);
                     auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
                     auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
                     for (int iy = 0; iy < nrc_y; ++iy) {
@@ -333,44 +315,9 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
                         auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j);
                         auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
                         accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
                         accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
-                        //accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]);
                     }
                 }
             }
-            //for (int j = 0; j < 128; j+=8) {
-            //    const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
-            //    const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
-            //    const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
-            //    const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
-            //    const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
-            //    const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
-            //    uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
-            //    uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
-            //    uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
-            //    uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
-            //    const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
-            //    const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
-            //    for (int iy = 0; iy < nrc_y; ++iy) {
-            //        accd[iy] = _mm256_fmadd_ps(
-            //            _mm256_load_ps(y[iy] + i*QK_K+j),
-            //            _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-            //            accd[iy]
-            //        );
-            //        accd[iy] = _mm256_fmadd_ps(
-            //            _mm256_load_ps(y[iy] + i*QK_K+j+128),
-            //            _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-            //            accd[iy]
-            //        );
-            //        accd2[iy] = _mm256_add_ps(
-            //            _mm256_load_ps(y[iy] + i*QK_K+j),
-            //            accd2[iy]
-            //        );
-            //        accd2[iy] = _mm256_add_ps(
-            //            _mm256_load_ps(y[iy] + i*QK_K+j+128),
-            //            accd2[iy]
-            //        );
-            //    }
-            //}
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
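For readers less fluent in AVX2, here is a scalar sketch of the seed packing the surviving inner loop decodes. This is my reading of the code above, not code from this PR: `iq4_kt_seed` is a hypothetical helper name, and `kNumGroups` plus the trellis generator live in `iqk_gemm_ktquants.cpp` and are not reproduced here.

```cpp
// Hypothetical scalar decoder for one iq4_kt trellis seed, read off the SIMD
// loop above (an illustration, not the PR's code). For reference, the per-group
// scale and offset derived from the 8 control words shb = x[i].qs are:
//   scale[k]  = d * (((shb[k] & 0xff) >> 1) - 64),  d = dptr[0] * 31.75f * 1.01f
//   offset[k] = 4096 + ((shb[k] & 1) << 15)
#include <cstdint>

static inline uint32_t iq4_kt_seed(const uint32_t * shb,  // x[i].qs, 8 control words
                                   const uint8_t  * ql,   // low bytes, (shb + 8)
                                   const uint8_t  * qh,   // high nibbles, ql + kNumGroups
                                   int ib, int j,         // ib in 0..3, j in 0..3
                                   int half) {            // 0: values 0..127, 1: values 128..255
    const int k = ib + 4*half;                        // which of the 8 scale/offset groups
    const uint32_t offset = 4096 + ((shb[k] & 1) << 15);
    const uint32_t sh = shb[k] >> (8 + 6*j);          // 6 bits: 3 for each of two seeds
    const int idx = 8*ib + 2*j;                       // byte index shared by both halves
    // First of the two seeds produced per (ib, j); the second uses ql[idx+1],
    // qh[idx+1], and ((sh & 56) << 9), i.e. the next 3 bits of sh in the same slot.
    return ql[idx + 32*half]                                 // 8 low bits
         + (((uint32_t)qh[idx] << (half ? 4 : 8)) & 0xf00)   // 4 bits from a qh nibble
         + ((sh & 7) << 12)                                  // top 3 bits from the control word
         + offset;                                           // 4096 + (parity << 15)
}
```

Under this reading, `val1` in the loop above is `iq4_kt_seed(shb, ql, qh, ib, j, 0)` and `val2` is `iq4_kt_seed(shb, ql, qh, ib, j, 1)`; the pairs `(val1, val3)` and `(val2, val4)` are then fed to `trellis.next8()` to produce eight f32 values each, which get scaled and accumulated against the two y halves at offsets `32*ib+8*j` and `32*ib+8*j+128`.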