mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 15:39:23 +00:00
iq3_kt: Metal GEMV
Performance is not as good as for iq2_kt: 40 t/s on my M2-Max for Llama-3.1-8B. Flipping signs is a costly affair.
This commit is contained in:
@@ -6759,7 +6759,6 @@ kernel void kernel_mul_mv_iq2_kt_f32(
|
||||
kernel_mul_mv_iq2_kt_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
// TODO
|
||||
void kernel_mul_mv_iq3_kt_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -6784,7 +6783,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const uint row_size = sizeof(float) + nb*sizeof(block_iq2_kt);
|
||||
const uint row_size = sizeof(float) + nb*sizeof(block_iq3_kt);
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
@@ -6801,7 +6800,8 @@ void kernel_mul_mv_iq3_kt_f32_impl(
|
||||
|
||||
device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
|
||||
|
||||
float4 v1, v2;
|
||||
float4 v[2];
|
||||
thread uint32_t * u32 = (thread uint32_t *)v;
|
||||
|
||||
float drow[N_DST];
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
@@ -6809,7 +6809,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
|
||||
drow[row] = dptr[0] * 31.75f * 1.05f;
|
||||
}
|
||||
|
||||
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
|
||||
device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 2) {
|
||||
|
||||
@@ -6818,14 +6818,26 @@ void kernel_mul_mv_iq3_kt_f32_impl(
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
device const uint16_t * q2 = (device const uint16_t *)(sc + 4);
|
||||
device const uint8_t * qh = (device const uint8_t *)(q2 + QK_K/8) + 16*(it%2);
|
||||
|
||||
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
|
||||
const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
|
||||
const uint8_t mask = 1 << (it/2);
|
||||
|
||||
Trellis::gen8(q2[2*it+0]+4096, v1, v2);
|
||||
auto sum = v1*y4[0] + v2*y4[1];
|
||||
Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
u32[j] &= 0x7fffffff;
|
||||
u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
|
||||
}
|
||||
|
||||
Trellis::gen8(q2[2*it+1]+4096, v1, v2);
|
||||
sum += v1*y4[2] + v2*y4[3];
|
||||
auto sum = v[0]*y4[0] + v[1]*y4[1];
|
||||
|
||||
Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
u32[j] &= 0x7fffffff;
|
||||
u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
|
||||
}
|
||||
|
||||
sum += v[0]*y4[2] + v[1]*y4[3];
|
||||
|
||||
sum *= ls;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user