iq3_k: fix Metal dot product

I was accessing the scales as 4-byte aligned, but iq3_k is
not 4-byte aligned. Instead of throwing an error (as it happens
on CUDA when one makes this mistake), Metal silently accepts
and we get garbage.
This commit is contained in:
Iwan Kawrakow
2024-10-14 10:15:25 +03:00
parent baab1d9a1e
commit 9802c771b8

View File

@@ -6463,7 +6463,6 @@ void kernel_mul_mv_iq3_k_f32_impl(
uint32_t vl[2], vh[2];
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
uint16_t shift[4];
for (int ib = ix; ib < nb; ib += 4) {
@@ -6479,18 +6478,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
device const block_iq3_k & xb = x[row*nb + ib];
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
device const uint32_t * sc = (device const uint32_t *)xb.scales_l;
device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l;
const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16);
scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1;
thread const int8_t * s8 = (thread const int8_t *)&scales32;
uint16_t extra = xb.extra >> (8*iq + is);
uint16_t extra = (xb.extra >> (8*iq + is)) << 3;
uint16_t signs = xb.scales_h >> (8*iq + is);
shift[0] = (extra << 3) & 8;
shift[1] = (extra << 2) & 8;
shift[2] = (extra << 1) & 8;
shift[3] = (extra << 0) & 8;
vl[0] = ql16[0] | ql16[1] << 16;
vl[1] = ql16[2] | ql16[3] << 16;
vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
@@ -6498,12 +6493,13 @@ void kernel_mul_mv_iq3_k_f32_impl(
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
constant float * values = kvalues_iq3k_f + shift[l];
constant float * values = kvalues_iq3k_f + (extra & 8);
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
vl[0] >>= 2; vl[1] >>= 2;
vh[0] >>= 1; vh[1] >>= 1;
extra >>= 2;
}
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +