mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
iq3_k: fix Metal dot product
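I was accessing the scales through a 4-byte-aligned (uint32_t) pointer, but the iq3_k block is not 4-byte aligned. Instead of raising an error (as happens on CUDA when one makes this mistake), Metal silently accepts the misaligned access and we get garbage. The scales are now read as two 16-bit halves and combined into a 32-bit value.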
This commit is contained in:
@@ -6463,7 +6463,6 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
uint32_t vl[2], vh[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
uint16_t shift[4];
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
@@ -6479,18 +6478,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
device const block_iq3_k & xb = x[row*nb + ib];
|
||||
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
|
||||
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
|
||||
device const uint32_t * sc = (device const uint32_t *)xb.scales_l;
|
||||
device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l;
|
||||
|
||||
const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
|
||||
uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16);
|
||||
scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&scales32;
|
||||
uint16_t extra = xb.extra >> (8*iq + is);
|
||||
uint16_t extra = (xb.extra >> (8*iq + is)) << 3;
|
||||
uint16_t signs = xb.scales_h >> (8*iq + is);
|
||||
|
||||
shift[0] = (extra << 3) & 8;
|
||||
shift[1] = (extra << 2) & 8;
|
||||
shift[2] = (extra << 1) & 8;
|
||||
shift[3] = (extra << 0) & 8;
|
||||
|
||||
vl[0] = ql16[0] | ql16[1] << 16;
|
||||
vl[1] = ql16[2] | ql16[3] << 16;
|
||||
vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
|
||||
@@ -6498,12 +6493,13 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
constant float * values = kvalues_iq3k_f + shift[l];
|
||||
constant float * values = kvalues_iq3k_f + (extra & 8);
|
||||
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
|
||||
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
vl[0] >>= 2; vl[1] >>= 2;
|
||||
vh[0] >>= 1; vh[1] >>= 1;
|
||||
extra >>= 2;
|
||||
}
|
||||
|
||||
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
|
||||
|
||||
Reference in New Issue
Block a user