From 9802c771b8149c721d1ba1fddd61e47ba3dbbbef Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 14 Oct 2024 10:15:25 +0300 Subject: [PATCH] iq3_k: fix Metal dot product I was accessing the scales as 4-byte aligned, but iq3_k is not 4-byte aligned. Instead of throwing an error (as it happens on CUDA when one makes this mistake), Metal silently accepts and we get garbage. --- ggml/src/ggml-metal.metal | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index fe197309..5e12bf1c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6463,7 +6463,6 @@ void kernel_mul_mv_iq3_k_f32_impl( uint32_t vl[2], vh[2]; uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; - uint16_t shift[4]; for (int ib = ix; ib < nb; ib += 4) { @@ -6479,18 +6478,14 @@ void kernel_mul_mv_iq3_k_f32_impl( device const block_iq3_k & xb = x[row*nb + ib]; device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir; device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir; - device const uint32_t * sc = (device const uint32_t *)xb.scales_l; + device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l; - const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1; + uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16); + scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1; thread const int8_t * s8 = (thread const int8_t *)&scales32; - uint16_t extra = xb.extra >> (8*iq + is); + uint16_t extra = (xb.extra >> (8*iq + is)) << 3; uint16_t signs = xb.scales_h >> (8*iq + is); - shift[0] = (extra << 3) & 8; - shift[1] = (extra << 2) & 8; - shift[2] = (extra << 1) & 8; - shift[3] = (extra << 0) & 8; - vl[0] = ql16[0] | ql16[1] << 16; vl[1] = ql16[2] | ql16[3] << 16; vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2; @@ -6498,12 +6493,13 @@ void kernel_mul_mv_iq3_k_f32_impl( float4 acc = {0.f}; for (int l = 0; l < 4; ++l) { - constant float * values = kvalues_iq3k_f + shift[l]; + constant float * values = kvalues_iq3k_f + (extra & 8); aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404); aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404); for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; vl[0] >>= 2; vl[1] >>= 2; vh[0] >>= 1; vh[1] >>= 1; + extra >>= 2; } sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +