New iq2_kt: Metal - very slow.

It seems Apple Silicon cannot quickly add four 8-bit ints.
Or perhaps I just don't know how to do it - I didn't find anything
about it in the Metal Shading Language Specification.
So, performance is quite a bit worse than with the original trellis.
This commit is contained in:
Iwan Kawrakow
2025-06-09 15:04:23 +03:00
parent d075a1c75b
commit f2be982fd8

View File

@@ -6613,6 +6613,18 @@ struct Trellis3 {
for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
return result;
}
// Produces eight values from the 32-bit input `val`: four are written to
// v1[0..3] and four to v2[0..3]. T4 is any 4-element type whose components
// can be assigned from an int (the callers in this file pass vector types).
template <typename T4>
static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
    // One 32-bit state per output lane, each from its own multiplier/offset pair.
    thread uint32_t state[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
    // Two masked words that are read back as eight signed bytes.
    uint32_t packed[2];
    thread const int8_t * bytes = (thread const int8_t *)packed;
    for (int lane = 0; lane < 4; ++lane) {
        const uint32_t s = state[lane];
        // First word: the masked state itself; second word: the state advanced
        // once more with the (ka3, kb3) pair, then masked.
        packed[0] = s & kmask;
        packed[1] = (ka3*s + kb3) & kmask;
        // Each output is the sum of the four bytes of one word, offset by -126.
        const int sum_lo = -126 + bytes[0] + bytes[1] + bytes[2] + bytes[3];
        const int sum_hi = -126 + bytes[4] + bytes[5] + bytes[6] + bytes[7];
        v1[lane] = sum_lo;
        v2[lane] = sum_hi;
    }
}
};
struct Trellis {
@@ -6710,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl(
float drow[N_DST];
for (int row = 0; row < N_DST; ++row) {
device const float * dptr = (device const float *)(cx + row*row_size);
drow[row] = dptr[0] * 31.75f * 1.05f;
drow[row] = dptr[0] * 1.05f;
}
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
@@ -6725,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl(
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
Trellis::gen8(q2[2*it+0]+4096, v1, v2);
Trellis3::gen8(q2[2*it+0]+4096, v1, v2);
auto sum = v1*y4[0] + v2*y4[1];
Trellis::gen8(q2[2*it+1]+4096, v1, v2);
Trellis3::gen8(q2[2*it+1]+4096, v1, v2);
sum += v1*y4[2] + v2*y4[3];
sum *= ls;
@@ -8561,19 +8573,18 @@ template <typename type4x4>
void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
int ib32 = il/2;
half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h;
half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h;
device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
half4 v1, v2;
char4 v1, v2;
for (int i = 0; i < 2; ++i) {
Trellis::gen8(q2[i]+4096, v1, v2);
v1 *= scale; v2 *= scale;
Trellis3::gen8(q2[i]+4096, v1, v2);
if constexpr (is_same_v<type4x4, half4x4>) {
reg[2*i+0] = v1;
reg[2*i+1] = v2;
reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]};
reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]};
} else {
reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]};
reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]};
reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]};
reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]};
}
}
}