q6_0: it now works on Metal

Outperforms q5_0 by a significant margin. E.g.
| model                          |       size |     params | backend    | ngl | threads |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | ---------------: |
| llama 8B Q6_0                  |   6.08 GiB |     8.03 B | Metal      | 100 |       4 |         tg128 |     44.02 ± 0.08 |
| llama 8B Q5_0                  |   5.21 GiB |     8.03 B | Metal      | 100 |       4 |         tg128 |     40.13 ± 0.12 |
| llama 8B Q6_0                  |   6.08 GiB |     8.03 B | Metal      | 100 |       4 |         pp512 |    500.55 ± 0.32 |
| llama 8B Q5_0                  |   5.21 GiB |     8.03 B | Metal      | 100 |       4 |         pp512 |    448.02 ± 0.27 |
This commit is contained in:
Iwan Kawrakow
2024-10-02 14:42:32 +03:00
parent aae268f7be
commit 0d0cd1ee68

View File

@@ -1286,7 +1286,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
}
// function for calculate inner product between half a q6_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q6 quants begin (0 or QK6_0/2)
// il indicates where the q6 quants begin (0 or QK6_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thread float * yl, int il) {
@@ -1294,14 +1294,15 @@ inline float block_q_n_dot_y(device const block_q6_0 * qb_curr, float sumy, thre
float2 acc = 0.f;
device const uint16_t * qh = ((device const uint16_t *)qb_curr->qh);
device const uint16_t * qs = qh + 4 + il/2;
device const uint16_t * qh = (device const uint16_t *)qb_curr->qh;
device const uint16_t * qs = (device const uint16_t *)qb_curr->qs + il/2;
const int shift = 4*(il/8);
for (int i = 0; i < 8; i += 2) {
acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << 4) & 0x00030))
+ yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << 12) & 0x03000));
acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << 6) & 0x00300))
+ yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << 14) & 0x30000));
acc[0] += yl[i + 0] * ((qs[i/2] & 0x000F) | ((qh[i/2] << (4-shift)) & 0x0030))
+ yl[i + 1] * ((qs[i/2] & 0x0F00) | ((qh[i/2] << (4-shift)) & 0x3000));
acc[1] += yl[i + 8] * ((qs[i/2] & 0x00F0) | ((qh[i/2] << (6-shift)) & 0x0300))
+ yl[i + 9] * ((qs[i/2] & 0xF000) | (((uint32_t)qh[i/2] << (6-shift)) & 0x30000));
}
return d * (sumy * -32.f + acc[0] + acc[1]);
}