diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index d1181b21..8981cda9 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6189,14 +6189,14 @@ void kernel_mul_mv_iq4_kss_f32_impl( device const float * yb = y + ix * QK_K + ib * 32 + il * 8; - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; float4 qf1, qf2; device const float * dptr = (device const float *)cx; d[0] = *dptr; - device const uint32_t * qptr = (device const uint32_t *)(dptr + 1); + device const uint32_t * qptr = (device const uint32_t *)(dptr + 1) + ix*(QK_K/8) + 4*ib; dptr += row_size/4; d[1] = *dptr; @@ -6205,7 +6205,7 @@ void kernel_mul_mv_iq4_kss_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - device const uint32_t * q4 = qptr + ibl*(QK_K/8) + 4*ib; + device const uint32_t * q4 = qptr; for (int row = 0; row < 2; ++row) { @@ -6213,26 +6213,26 @@ void kernel_mul_mv_iq4_kss_f32_impl( int16_t ls = (s32 | (s32 >> 15)) & 0xff; threadgroup const float * block_values = shared_values + ((ls & 1) << 4); - const int scale = (ls & 254) - 127; + const float scale = ((ls & 254) - 127); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[2*il+0] & 0xfffefffe; - aux32[0] ^= (aux32[0] >> 1); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; + uint32_t v32 = q4[2*il+0] & 0xfffefffe; + v32 ^= (v32 >> 1); + aux32 = v32 & 0x0f0f0f0f; qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; - qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]}; acc1 += yl[0] * qf1; + aux32 = (v32 >> 4) & 0x0f0f0f0f; + qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; acc2 += yl[1] * qf2; - aux32[0] = q4[2*il+1] & 0xfffefffe; - aux32[0] ^= (aux32[0] >> 1); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; + v32 = q4[2*il+1] & 0xfffefffe; + v32 ^= (v32 >> 1); + aux32 = v32 & 0x0f0f0f0f; qf1 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; - qf2 = {block_values[q8[4]], block_values[q8[5]], block_values[q8[6]], block_values[q8[7]]}; acc1 += yl[2] * qf1; + aux32 = (v32 >> 4) & 0x0f0f0f0f; + qf2 = {block_values[q8[0]], block_values[q8[1]], block_values[q8[2]], block_values[q8[3]]}; acc2 += yl[3] * qf2; acc1 += acc2; @@ -6244,6 +6244,7 @@ void kernel_mul_mv_iq4_kss_f32_impl( } yb += 2 * QK_K; + qptr += 2 * (QK_K/8); } sumf = simd_sum(sumf);