From f9f15c27b694b5df0bd6aa29b6366b502221555e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 13 Oct 2024 12:23:14 +0300 Subject: [PATCH] iq2_ks: faster Metal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LLaMA-3.1-8B: PP-512 = 475.22 ± 0.37 t/s TG-128 = 45.32 ± 0.03 t/s --- ggml/src/ggml-metal.m | 7 +++++- ggml/src/ggml-metal.metal | 53 ++++++++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 533f2fa5..d5e8d6ae 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2289,10 +2289,15 @@ static enum ggml_status ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ2_KS || + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ2_KS) { + const int mem_size = 64*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 31d38df9..5ed424d3 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3685,6 +3685,7 @@ constexpr constant static float kvalues_iq6k_f[128] = { }; constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f }; +constexpr constant static half kvalues_iq2k_h[8] = { -31.h, -13.h, 1.h, 17.h, -26.h, -8.h, 6.h, 22.h }; constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f }; constexpr constant static half kvalues_iq3k_h[16] = { -63.h, -40.h, -23.h, -10.h, 1.h, 13.h, 28.h, 47.h, -59.h, -36.h, -19.h, -6.h, 5.h, 17.h, 32.h, 51.h }; @@ -6304,6 +6305,24 @@ void kernel_mul_mv_iq2_ks_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg; + { + //int row = tiisg%N_DST; + //device const half * dptr = (device const half *)(cx + row*row_size); + //const float d = *dptr; + //all_values[8*row + tiisg/N_DST] = d*iq2nl_values[tiisg/N_DST]; + //threadgroup_barrier(mem_flags::mem_threadgroup); + int row = tiisg/8; + int pos = tiisg%8; + device const half * dptr = (device const half *)(cx + row*row_size); + const float d = *dptr; + all_values[8*row + pos] = d*kvalues_iq2k_f[pos]; + simdgroup_barrier(mem_flags::mem_none); + //threadgroup_barrier(mem_flags::mem_threadgroup); + } + + cx += sizeof(half); + uint32_t q32[2]; uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; @@ -6317,15 +6336,14 @@ void kernel_mul_mv_iq2_ks_f32_impl( yl[i+24] = y4[i+96]; } + device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib; + device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir; + device const uint16_t * sc = (device const uint16_t *)x->scales; + device const uint16_t * ex = (device const uint16_t *)&x->extra; + for (int row = 0; row < N_DST; row++) { - device const half * dptr = (device const half *)(cx + row*row_size); - const float d = *dptr; - - device const block_iq2_ks * x = (device const block_iq2_ks *)(dptr + 1); - device const block_iq2_ks & xb = x[ib]; - device const uint16_t * q16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir; - device const uint16_t * sc = (device const uint16_t *)xb.scales; + threadgroup const float * row_values = all_values + 8*row; uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f; thread const int8_t * s8 = (thread const int8_t *)&sc32; @@ -6333,18 +6351,22 @@ void kernel_mul_mv_iq2_ks_f32_impl( q32[0] = q16[0] | (q16[1] << 16); q32[1] = q16[2] | (q16[3] << 16); - uint8_t extra = xb.extra << 4*(1-iq); + uint8_t extra = ex[0] << 4*(1-iq); float4 acc = {0.f}; for (int l = 0; l < 4; ++l) { - constant float * values = kvalues_iq2k_f + ((extra >> (2 + l)) & 4); + threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4); aux32[0] = (q32[0] >> 2*l) & 0x03030303; aux32[1] = (q32[1] >> 2*l) & 0x03030303; for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; } - extra = xb.extra >> (8 + 4*iq); - sumf[row] += d * (acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16)) - + acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16))); + extra = ex[0] >> (8 + 4*iq); + sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16)) + + acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16)); + + q16 += row_size/2; + sc += row_size/2; + ex += row_size/2; } @@ -6381,11 +6403,12 @@ kernel void kernel_mul_mv_iq2_ks_f32( constant int64_t & ne1, constant uint & r2, constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_k_f32_impl( @@ -7703,8 +7726,8 @@ void dequantize_iq2_ks(device const block_iq2_ks * xb, short il, thread type4x4 const short ib32 = il/2; half d = (((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf) - ((xb->extra >> (8 + ib32)) & 1 ? 0 : 16)); - constant int8_t * int_values = iq2nl_values + 4*((xb->extra >> ib32) & 1); - half4 values = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3] }; + constant half4 * half_values = (constant half4 *)kvalues_iq2k_h; + half4 values = half_values[(xb->extra >> ib32) & 1] * d; const int shift = 2*((il%8)/2); thread uint16_t aux16[2];