mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq3_ks: Metal gemv - pathetic performance
This commit is contained in:
@@ -7396,15 +7396,15 @@ void kernel_mul_mv_iq3_ks_f32_impl(
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const int ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
const uint row_size = sizeof(half) + nb*sizeof(block_iq3_ks);
|
||||
device const char * cx = (device const char *)src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *)src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 16*sgitg;
|
||||
{
|
||||
@@ -7414,12 +7414,19 @@ void kernel_mul_mv_iq3_ks_f32_impl(
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f};
|
||||
float d[N_DST];
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
const int ir = it%4; // 0...3
|
||||
const int is = (8*ir)/16;// 0 or 1
|
||||
|
||||
device const half * dptr = (device const half *)cx;
|
||||
d[0] = (float)dptr[0];
|
||||
for (int i = 1; i < N_DST; ++i) {
|
||||
dptr += row_size/2;
|
||||
d[i] = (float)dptr[0];
|
||||
}
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
@@ -7438,16 +7445,22 @@ void kernel_mul_mv_iq3_ks_f32_impl(
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
device const block_iq3_k & xb = x[row*nb + ib];
|
||||
device const block_iq3_ks * x = (device const block_iq3_ks *)(cx + row_size*row + sizeof(half));
|
||||
device const block_iq3_ks & xb = x[ib];
|
||||
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
|
||||
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
|
||||
device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l;
|
||||
device const uint16_t * sc16 = (device const uint16_t *)xb.scales;
|
||||
|
||||
uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16);
|
||||
scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&scales32;
|
||||
uint16_t extra = (xb.extra >> (8*iq + is)) << 3;
|
||||
uint16_t signs = xb.scales_h >> (8*iq + is);
|
||||
uint8_t extra_s = (xb.extra & 0xff) >> 4*iq;
|
||||
uint8_t extra_v = xb.extra >> (8 + 4*iq);
|
||||
|
||||
uint32_t scales32 = sc16[0] | (sc16[1] << 16);
|
||||
scales32 = (scales32 >> 4*iq) & 0x0f0f0f0f;
|
||||
thread int8_t * s8 = (thread int8_t *)&scales32;
|
||||
s8[0] += ((extra_s << 4) & 0x10) - 16;
|
||||
s8[1] += ((extra_s << 3) & 0x10) - 16;
|
||||
s8[2] += ((extra_s << 2) & 0x10) - 16;
|
||||
s8[3] += ((extra_s << 1) & 0x10) - 16;
|
||||
|
||||
vl[0] = ql16[0] | ql16[1] << 16;
|
||||
vl[1] = ql16[2] | ql16[3] << 16;
|
||||
@@ -7456,18 +7469,16 @@ void kernel_mul_mv_iq3_ks_f32_impl(
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
threadgroup const float * values = all_values + (extra & 8);
|
||||
//constant float * values = kvalues_iq3k_f + (extra & 8);
|
||||
threadgroup const float * values = all_values + ((extra_v & 1) << 3);
|
||||
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
|
||||
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
vl[0] >>= 2; vl[1] >>= 2;
|
||||
vh[0] >>= 1; vh[1] >>= 1;
|
||||
extra >>= 2;
|
||||
extra_v >>= 1;
|
||||
}
|
||||
|
||||
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
|
||||
acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3]));
|
||||
sumf[row] += d[row] * (acc[0] * s8[0] + acc[1] * s8[1] + acc[2] * s8[2] + acc[3] * s8[3]);
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user