iq3_kt: Metal GEMV

Performance is not as good as for iq2_kt: 40 t/s on my M2-Max for LLaMA-3.1-8B.
Flipping signs is a costly affair.
This commit is contained in:
Iwan Kawrakow
2025-05-30 13:21:25 +03:00
parent 2396cc3f88
commit ad52554a5e

View File

@@ -6759,7 +6759,6 @@ kernel void kernel_mul_mv_iq2_kt_f32(
kernel_mul_mv_iq2_kt_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
// TODO
void kernel_mul_mv_iq3_kt_f32_impl(
device const void * src0,
device const float * src1,
@@ -6784,7 +6783,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const uint row_size = sizeof(float) + nb*sizeof(block_iq2_kt);
const uint row_size = sizeof(float) + nb*sizeof(block_iq3_kt);
const uint i12 = im%ne12;
const uint i13 = im/ne12;
@@ -6801,7 +6800,8 @@ void kernel_mul_mv_iq3_kt_f32_impl(
device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
float4 v1, v2;
float4 v[2];
thread uint32_t * u32 = (thread uint32_t *)v;
float drow[N_DST];
for (int row = 0; row < N_DST; ++row) {
@@ -6809,7 +6809,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
drow[row] = dptr[0] * 31.75f * 1.05f;
}
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
for (int ib = ix; ib < nb; ib += 2) {
@@ -6818,14 +6818,26 @@ void kernel_mul_mv_iq3_kt_f32_impl(
for (int row = 0; row < N_DST; row++) {
device const uint16_t * q2 = (device const uint16_t *)(sc + 4);
device const uint8_t * qh = (device const uint8_t *)(q2 + QK_K/8) + 16*(it%2);
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
const uint8_t mask = 1 << (it/2);
Trellis::gen8(q2[2*it+0]+4096, v1, v2);
auto sum = v1*y4[0] + v2*y4[1];
Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
for (int j = 0; j < 8; ++j) {
u32[j] &= 0x7fffffff;
u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
}
Trellis::gen8(q2[2*it+1]+4096, v1, v2);
sum += v1*y4[2] + v2*y4[3];
auto sum = v[0]*y4[0] + v[1]*y4[1];
Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
for (int j = 0; j < 8; ++j) {
u32[j] &= 0x7fffffff;
u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
}
sum += v[0]*y4[2] + v[1]*y4[3];
sum *= ls;