iq3_ks: Metal gemv - pathetic performance

This commit is contained in:
Iwan Kawrakow
2025-07-01 12:50:57 +02:00
parent f7b3c07c92
commit 59967f3d64

View File

@@ -7396,15 +7396,15 @@ void kernel_mul_mv_iq3_ks_f32_impl(
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
const uint row_size = sizeof(half) + nb*sizeof(block_iq3_ks);
device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
threadgroup float * all_values = (threadgroup float *)shared_values + 16*sgitg;
{
@@ -7414,12 +7414,19 @@ void kernel_mul_mv_iq3_ks_f32_impl(
float yl[32];
float sumf[N_DST]={0.f};
float d[N_DST];
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3
const int is = (8*ir)/16;// 0 or 1
device const half * dptr = (device const half *)cx;
d[0] = (float)dptr[0];
for (int i = 1; i < N_DST; ++i) {
dptr += row_size/2;
d[i] = (float)dptr[0];
}
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
@@ -7438,16 +7445,22 @@ void kernel_mul_mv_iq3_ks_f32_impl(
for (int row = 0; row < N_DST; row++) {
device const block_iq3_k & xb = x[row*nb + ib];
device const block_iq3_ks * x = (device const block_iq3_ks *)(cx + row_size*row + sizeof(half));
device const block_iq3_ks & xb = x[ib];
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l;
device const uint16_t * sc16 = (device const uint16_t *)xb.scales;
uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16);
scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1;
thread const int8_t * s8 = (thread const int8_t *)&scales32;
uint16_t extra = (xb.extra >> (8*iq + is)) << 3;
uint16_t signs = xb.scales_h >> (8*iq + is);
uint8_t extra_s = (xb.extra & 0xff) >> 4*iq;
uint8_t extra_v = xb.extra >> (8 + 4*iq);
uint32_t scales32 = sc16[0] | (sc16[1] << 16);
scales32 = (scales32 >> 4*iq) & 0x0f0f0f0f;
thread int8_t * s8 = (thread int8_t *)&scales32;
s8[0] += ((extra_s << 4) & 0x10) - 16;
s8[1] += ((extra_s << 3) & 0x10) - 16;
s8[2] += ((extra_s << 2) & 0x10) - 16;
s8[3] += ((extra_s << 1) & 0x10) - 16;
vl[0] = ql16[0] | ql16[1] << 16;
vl[1] = ql16[2] | ql16[3] << 16;
@@ -7456,18 +7469,16 @@ void kernel_mul_mv_iq3_ks_f32_impl(
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
threadgroup const float * values = all_values + (extra & 8);
//constant float * values = kvalues_iq3k_f + (extra & 8);
threadgroup const float * values = all_values + ((extra_v & 1) << 3);
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
vl[0] >>= 2; vl[1] >>= 2;
vh[0] >>= 1; vh[1] >>= 1;
extra >>= 2;
extra_v >>= 1;
}
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3]));
sumf[row] += d[row] * (acc[0] * s8[0] + acc[1] * s8[1] + acc[2] * s8[2] + acc[3] * s8[3]);
}