diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 0e4fe49c..d1181b21 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6185,6 +6185,7 @@ void kernel_mul_mv_iq4_kss_f32_impl( float4 yl[4]; float2 sumf = 0.f; + float d[2]; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -6193,19 +6194,21 @@ void kernel_mul_mv_iq4_kss_f32_impl( float4 qf1, qf2; + device const float * dptr = (device const float *)cx; + d[0] = *dptr; + device const uint32_t * qptr = (device const uint32_t *)(dptr + 1); + dptr += row_size/4; + d[1] = *dptr; + for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - device const float * dptr = (device const float *)cx; + device const uint32_t * q4 = qptr + ibl*(QK_K/8) + 4*ib; for (int row = 0; row < 2; ++row) { - const float d = *dptr; - device const block_iq4_kss * x = (device const block_iq4_kss *)(dptr + 1); - device const uint32_t * q4 = (device const uint32_t *)x[ibl].qs + 4*ib; - uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); int16_t ls = (s32 | (s32 >> 15)) & 0xff; @@ -6234,9 +6237,9 @@ void kernel_mul_mv_iq4_kss_f32_impl( acc1 += acc2; - sumf[row] += d * scale * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + sumf[row] += d[row] * scale * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - dptr += row_size/4; + q4 += row_size/4; }