From 0e12b2919c8a3c47ecf30591c4a22dfdea9ba901 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 9 Oct 2024 11:06:45 +0300 Subject: [PATCH] iq4_xxs: slightly faster TG on Metal --- ggml/src/ggml-metal.metal | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index ade6cc1c..681af7ef 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6077,7 +6077,7 @@ void kernel_mul_mv_iq4_xxs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float2 sumf = 0.f; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -6091,9 +6091,11 @@ void kernel_mul_mv_iq4_xxs_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + device const float * dptr = (device const float *)cx; + for (int row = 0; row < 2; ++row) { - device const float * dptr = (device const float *)(cx + row*row_size); + //device const float * dptr = (device const float *)(cx + row*row_size); const float d = *dptr; device const block_iq4_xxs * x = (device const block_iq4_xxs *)(dptr + 1); device const block_iq4_xxs & xb = x[ibl]; @@ -6122,16 +6124,16 @@ void kernel_mul_mv_iq4_xxs_f32_impl( const int ls = (xb.scales[ib] & 254) - 127; sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + dptr += row_size/4; + } yb += 2 * QK_K; } - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } + sumf = simd_sum(sumf); + if (tiisg < 2) { + dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg]; } }