diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 29972071..e5bf4967 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -632,9 +632,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn const float row_av = dptr[1]; const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); - for (int iy = 0; iy < nrc_y * 2; ++iy) { - accd[iy] = vdupq_n_f32(0.0f); - } + for (int iy = 0; iy < nrc_y * 2; ++iy) accd[iy] = vdupq_n_f32(0.0f); for (int i = 0; i < nb; ++i) { const uint32_t * shb = x[i].qs; @@ -652,17 +650,12 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn uint32_t sh1 = shb[ib+0] >> 8; uint32_t sh2 = shb[ib+4] >> 8; - for (int jj = 0; jj < 4; ++jj) { - //int j = 32*ib + 8*jj; - // -> (j/8)%4 = (4*ib+jj)%4 = jj%4; - // j/4 = 8*ib + 2*jj; - //const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4)); - //const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4)); + for (int j = 0; j < 4; ++j) { - uint32_t val1 = ql[8*ib+2*jj+ 0] + ((qh[8*ib+2*jj+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - uint32_t val2 = ql[8*ib+2*jj+32] + ((qh[8*ib+2*jj+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - uint32_t val3 = ql[8*ib+2*jj+ 1] + ((qh[8*ib+2*jj+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - uint32_t val4 = ql[8*ib+2*jj+33] + ((qh[8*ib+2*jj+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; sh1 >>= 6; sh2 >>= 6; @@ -675,8 +668,8 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn x2.val[1] = vmulq_f32(scale2, x2.val[1]); for (int iy = 0; iy < nrc_y; ++iy) { - auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj); - auto y2 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj + 128); + auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j); + auto y2 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j + 128); accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y1.val[0], x1.val[0]); accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y1.val[1], x1.val[1]); @@ -689,10 +682,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; ++iy) { - // Sum the two accumulators for this y row float32x4_t sum1 = vaddq_f32(accd[iy*2], accd[iy*2 + 1]); - - // Compute final result float result = d*vaddvq_f32(sum1) + row_av*row_sum[iy]; info.store(ix, iy, result); }