diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 5b9ba8a4..1203ebfa 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1286,7 +1286,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre } // function for calculate inner product between half a q6_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q6 quants begin (0 or QK6_0/2) +// il indicates where the q6 quants begin (0 or QK6_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thread float * yl, int il) { @@ -1294,14 +1294,15 @@ inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thre float2 acc = 0.f; - device const uint16_t * qh = ((device const uint16_t *)qb_curr->qh); - device const uint16_t * qs = qh + 4 + il/2; + device const uint16_t * qh = (device const uint16_t *)qb_curr->qh; + device const uint16_t * qs = (device const uint16_t *)qb_curr->qs + il/2; + const int shift = 4*(il/8); for (int i = 0; i < 8; i += 2) { - acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << 4) & 0x00030)) - + yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << 12) & 0x03000)); - acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << 6) & 0x00300)) - + yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << 14) & 0x30000)); + acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << (4-shift)) & 0x0030)) + + yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << (4-shift)) & 0x3000)); + acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << (6-shift)) & 0x0300)) + + yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << (6-shift)) & 0x30000)); } return d * (sumy * -32.f + acc[0] + acc[1]); }