From da9fdf57d847ea9594083bf633cffacf25ac3ca7 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 5 Nov 2024 09:46:43 +0100 Subject: [PATCH] Faster iq4_k: Metal --- ggml/src/ggml-metal.metal | 98 ++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 54 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 89cd412a..fa2ec232 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6492,73 +6492,59 @@ void kernel_mul_mv_iq4_k_f32_impl( const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; shared_values[tiisg] = kvalues_iq4k_f[tiisg]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[2]={0.f}, all_sum; + float2 sumf = 0.f; - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + device const float * yb = y + ix * QK_K + it * 4; - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - float4 qf1, qf2; + float4 qf; for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - //float2 sumy; - //sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]); - //sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]); + yl[0] = y4[0]; yl[1] = y4[16]; yl[2] = y4[32]; yl[3] = y4[48]; for (int row = 0; row < 2; ++row) { device const block_iq4_k & xb = x[row*nb + ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 4*it); - uint16_t extra = xb.extra >> 2*ib; - threadgroup const float * values1 = shared_values + 16*(extra & 1); - threadgroup const float * values2 = shared_values + 8*(extra & 2); + uint16_t extra = xb.extra >> it; + threadgroup const float * values = shared_values + ((extra & 1) << 4); - float4 acc1 = {0.f}, acc2 = {0.f}; + aux32 = q4[0] & 0x0f0f0f0f; + qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]}; + float4 acc = yl[0] * qf; + aux32 = (q4[0] >> 4) & 0x0f0f0f0f; + qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]}; + acc += yl[1] * qf; - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; - qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; + aux32 = q4[16] & 0x0f0f0f0f; + qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]}; + acc += yl[2] * qf; + aux32 = (q4[16] >> 4) & 0x0f0f0f0f; + qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]}; + acc += yl[3] * qf; - aux32[0] = q4[1] & 0x0f0f0f0f; - aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]}; - qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2); - const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; - const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; - sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3])); - //uint16_t extra = xb.extra >> 2*ib; - //sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) + - // ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0))); + const uint8_t h = xb.scales_h[it/4] >> 2*(it%4); + const int ls = (((xb.scales_l[it/2] >> 4*(it%2)) & 0xf) | ((h << 4) & 0x30)) - 32; + sumf[row] += (float)xb.d * ls * (acc[0] + acc[1] + acc[2] + acc[3]); } yb += 2 * QK_K; } - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } + sumf = simd_sum(sumf); + if (tiisg < 2) { + dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg]; } } @@ -7692,22 +7678,26 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & template void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - const int l = il%2; - // l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16 - device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; - const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4); - const float d = (float)xb->d * (ls - 32); + int i16 = 4*(il%4); + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 16*(il/8) + i16; + const int shift = 4*((il/4)%2); + uint16_t extra = xb->extra >> i16; + const float d = xb->d; + float4 dl; + dl[0] = d * (((xb->scales_l[2*(il%4)+0] & 0xf) | ((xb->scales_h[il%4] << 4) & 0x30)) - 32); + dl[1] = d * (((xb->scales_l[2*(il%4)+0] >> 4) | ((xb->scales_h[il%4] << 2) & 0x30)) - 32); + dl[2] = d * (((xb->scales_l[2*(il%4)+1] & 0xf) | ((xb->scales_h[il%4] >> 0) & 0x30)) - 32); + dl[3] = d * (((xb->scales_l[2*(il%4)+1] >> 4) | ((xb->scales_h[il%4] >> 2) & 0x30)) - 32); uint32_t aux32; thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1); for (int i = 0; i < 4; ++i) { - aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f; - reg[i][0] = d * values[q8[0]]; - reg[i][1] = d * values[q8[1]]; - reg[i][2] = d * values[q8[2]]; - reg[i][3] = d * values[q8[3]]; + constant float * values = kvalues_iq4k_f + ((extra & 1) << 4); + aux32 = (q4[i] >> shift) & 0x0f0f0f0f; + reg[i][0] = dl[i] * values[q8[0]]; + reg[i][1] = dl[i] * values[q8[1]]; + reg[i][2] = dl[i] * values[q8[2]]; + reg[i][3] = dl[i] * values[q8[3]]; + extra >>= 1; } }