From f2be982fd8d0f7a2176c0437640ade4225ad5932 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Jun 2025 15:04:23 +0300 Subject: [PATCH] New iq2_kt: Metal - very slow. It seems Apple Silicon cannot quickly add 4 8-bit ints. Or I don't know how to do it - but I didn't find anything in the Metal Shading Language Specification. So, performance is quite a bit worse than the original trellis. --- ggml/src/ggml-metal.metal | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index e1d47404..c3c4f0bb 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6613,6 +6613,18 @@ struct Trellis3 { for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; return result; } + template + static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) { + thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3}; + uint32_t aux32[2]; + thread const int8_t * a8 = (thread const int8_t *)aux32; + for (int i = 0; i < 4; ++i) { + aux32[0] = aux[i] & kmask; + aux32[1] = (ka3*aux[i] + kb3) & kmask; + v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3]; + v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7]; + } + } }; struct Trellis { @@ -6710,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl( float drow[N_DST]; for (int row = 0; row < N_DST; ++row) { device const float * dptr = (device const float *)(cx + row*row_size); - drow[row] = dptr[0] * 31.75f * 1.05f; + drow[row] = dptr[0] * 1.05f; } device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float)); @@ -6725,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl( const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf]; - Trellis::gen8(q2[2*it+0]+4096, v1, v2); + Trellis3::gen8(q2[2*it+0]+4096, v1, v2); auto sum = v1*y4[0] + v2*y4[1]; - Trellis::gen8(q2[2*it+1]+4096, v1, v2); + Trellis3::gen8(q2[2*it+1]+4096, v1, v2); sum += v1*y4[2] + v2*y4[3]; sum *= ls; @@ -8561,19 +8573,18 @@ template void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 int ib32 = il/2; - half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h; + half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h; device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2); - half4 v1, v2; + char4 v1, v2; for (int i = 0; i < 2; ++i) { - Trellis::gen8(q2[i]+4096, v1, v2); - v1 *= scale; v2 *= scale; + Trellis3::gen8(q2[i]+4096, v1, v2); if constexpr (is_same_v) { - reg[2*i+0] = v1; - reg[2*i+1] = v2; + reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]}; + reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]}; } else { - reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]}; - reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]}; + reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]}; + reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]}; } } }