iq4_kss: somewhat faster Metal dot product

45.75 t/s -> 48.75 t/s.
Still 22% slower than q4_0
This commit is contained in:
Iwan Kawrakow
2024-10-16 08:47:14 +03:00
parent e01045b02e
commit 7cbe979ee0

View File

@@ -6185,6 +6185,7 @@ void kernel_mul_mv_iq4_kss_f32_impl(
float4 yl[4];
float2 sumf = 0.f;
float d[2];
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
@@ -6193,19 +6194,21 @@ void kernel_mul_mv_iq4_kss_f32_impl(
float4 qf1, qf2;
device const float * dptr = (device const float *)cx;
d[0] = *dptr;
device const uint32_t * qptr = (device const uint32_t *)(dptr + 1);
dptr += row_size/4;
d[1] = *dptr;
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
device const float * dptr = (device const float *)cx;
device const uint32_t * q4 = qptr + ibl*(QK_K/8) + 4*ib;
for (int row = 0; row < 2; ++row) {
const float d = *dptr;
device const block_iq4_kss * x = (device const block_iq4_kss *)(dptr + 1);
device const uint32_t * q4 = (device const uint32_t *)x[ibl].qs + 4*ib;
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
int16_t ls = (s32 | (s32 >> 15)) & 0xff;
@@ -6234,9 +6237,9 @@ void kernel_mul_mv_iq4_kss_f32_impl(
acc1 += acc2;
sumf[row] += d * scale * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
sumf[row] += d[row] * scale * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
dptr += row_size/4;
q4 += row_size/4;
}