From a0ba58e9b977ae13bf47d08b12745ee4325fe015 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 20 Jun 2025 10:47:22 +0300 Subject: [PATCH] iq2_kt and iq3_kt work with new int trellis Much slower than the fp16 based trellis. I guess, Apple doesn't have int8_t SIMD on the M2-Max GPU. --- ggml/src/ggml-metal.metal | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c3c4f0bb..e3bd070d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6598,16 +6598,12 @@ void kernel_mul_mv_iq2_k_f32_impl( struct Trellis3 { constexpr constant static uint32_t kmask = 0x3f3f3f3f; - constexpr constant static uint32_t ka = 89226354; - constexpr constant static uint32_t kb = 64248484; + constexpr constant static uint32_t ka = 0xCBAC1FED; constexpr constant static uint32_t ka1 = ka*ka; - constexpr constant static uint32_t kb1 = kb*ka+kb; constexpr constant static uint32_t ka2 = ka1*ka; - constexpr constant static uint32_t kb2 = kb1*ka+kb; constexpr constant static uint32_t ka3 = ka2*ka; - constexpr constant static uint32_t kb3 = kb2*ka+kb; static inline char4 gen4(uint32_t val) { - thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask}; + thread uint32_t aux[4] = {(ka*val) & kmask, (ka1*val) & kmask, (ka2*val) & kmask, (ka3*val) & kmask}; thread const int8_t * a8 = (thread const int8_t *)aux; char4 result; for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]; @@ -6615,14 +6611,18 @@ struct Trellis3 { } template static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) { - thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3}; + thread uint32_t aux[4] = {ka*val, ka1*val, ka2*val, ka3*val}; uint32_t aux32[2]; thread const int8_t * a8 = (thread const int8_t *)aux32; + //thread const char4 * a8 = (thread const char4 *)aux32; for (int i = 0; i < 4; ++i) { aux32[0] = aux[i] & kmask; - aux32[1] = (ka3*aux[i] + kb3) & kmask; + aux32[1] = (ka3*aux[i]) & kmask; v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3]; v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7]; + // Much slower: + //v1[i] = -126 + a8[0][0] + a8[0][1] + a8[0][2] + a8[0][3]; + //v2[i] = -126 + a8[1][0] + a8[1][1] + a8[1][2] + a8[1][3]; } } }; @@ -6837,7 +6837,7 @@ void kernel_mul_mv_iq3_kt_f32_impl( float drow[N_DST]; for (int row = 0; row < N_DST; ++row) { device const float * dptr = (device const float *)(cx + row*row_size); - drow[row] = dptr[0] * 31.75f * 1.01f; + drow[row] = dptr[0] * 1.01f; } device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float)); @@ -6854,7 +6854,7 @@ void kernel_mul_mv_iq3_kt_f32_impl( const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf); const uint8_t mask = 1 << (it/2); - Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]); + Trellis3::gen8(q2[2*it+0]+4096, v[0], v[1]); for (int j = 0; j < 8; ++j) { u32[j] &= 0x7fffffff; u32[j] |= qh[j+0] & mask ? 0x80000000 : 0; @@ -6862,7 +6862,7 @@ void kernel_mul_mv_iq3_kt_f32_impl( auto sum = v[0]*y4[0] + v[1]*y4[1]; - Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]); + Trellis3::gen8(q2[2*it+1]+4096, v[0], v[1]); for (int j = 0; j < 8; ++j) { u32[j] &= 0x7fffffff; u32[j] |= qh[j+8] & mask ? 0x80000000 : 0; @@ -8593,17 +8593,14 @@ template void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 int ib32 = il/2; - half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 31.75h * 1.01h; + half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 1.01h; device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2); device const uint8_t * qh = x->qh + 16*(il%2); const uint8_t mask = 1 << ib32; half4 v1, v2; for (int i = 0; i < 2; ++i) { - Trellis::gen8(q2[i]+4096, v1, v2); - //v1 *= scale; v2 *= scale; - //for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -abs(v1[j]) : abs(v1[j]); - //for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -abs(v2[j]) : abs(v2[j]); + Trellis3::gen8(q2[i]+4096, v1, v2); v1 = abs(v1)*scale; v2 = abs(v2)*scale; for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -v1[j] : v1[j]; for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -v2[j] : v2[j];