From 9083a50eae4be5b3e61e4f3925df1983b78b4f9c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 19 Sep 2024 18:39:34 +0300 Subject: [PATCH] POC per row scale: iq2_tn on Metal --- ggml/src/ggml-metal.metal | 66 ++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 259fa609..5d65f264 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3778,21 +3778,21 @@ void kernel_mul_mv_iq2_tn_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; + const int row_size = nb*sizeof(block_iq2_tn) + 4; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = ((i12/r2)*ne01 + (i13/r3)*ne01*ne02)*row_size; - device const block_iq2_tn * x = (device const block_iq2_tn *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const char * cx = (device const char *) src0 + first_row*row_size + offset0; + device const float * y = (device const float*) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_iq2_tn) * nb / 2; + float drow[N_DST]; const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 @@ -3802,6 +3802,8 @@ void kernel_mul_mv_iq2_tn_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + for (int row = 0; row < N_DST; row++) drow[row] = *((device const float *)(cx + row*row_size)); + for (int ib = ix; ib < nb; ib += 4) { float sumy = 0.f; @@ -3812,7 +3814,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( yl[i+24] = y4[i+96]; sumy += yl[i+24]; } - device const half * dh = &x[ib].d; + device const block_iq2_tn * x = (device const block_iq2_tn *)(cx + 4); device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; for (int row = 0; row < N_DST; row++) { @@ -3829,14 +3831,12 @@ void kernel_mul_mv_iq2_tn_f32_impl( acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); } - float dall = dh[0]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy); + sumf[row] += (acc1[0] + 1.f/256.f * acc2[0]) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * 1.f/64.f - sumy; - qs += step; - dh += step; + qs += row_size/2; } y4 += 4 * QK_K; @@ -3845,7 +3845,7 @@ void kernel_mul_mv_iq2_tn_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = drow[row]*all_sum; } } } @@ -6850,16 +6850,14 @@ void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & re template void dequantize_iq2_tn(device const block_iq2_tn * xb, short il, thread type4x4 & reg) { - const half d = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs + 32*(il/8) + 16*(il&1); il = (il/2)%4; half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const half dl = d * coef; for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - d; + reg[i/4][i%4] = coef * (q[i] & mask) - 1; } } @@ -7472,6 +7470,32 @@ struct DequantizerIQ1TN { half d; }; +template +struct DequantizerIQ2TN { + using type4x4 = T4x4; + using Block = block_iq2_tn; + constexpr constant static int nl = 16; + DequantizerIQ2TN(device const char * cx, short il = 0) : il(il) { + d = *(device const float *)cx; + x = (device const Block *)(cx + sizeof(float)); + } + inline void convert(thread T4x4& t) const { + dequantize_iq2_tn(x, il, t); + t *= d; + } + inline void convert(int64_t ind, thread T4x4& t) { + dequantize_iq2_tn(x + ind/nl, ind%nl, t); + t *= d; + } + inline void next() { + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + } + device const Block * x; + short il; + float d; +}; + // each block_q contains 16*nl weights template kernel void kernel_mul_mm(device const uchar * src0, @@ -7821,7 +7845,6 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -7843,6 +7866,7 @@ template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq1_tn")]] kernel get_rows_q_t kernel_get_rows_q2>; +template [[host_name("kernel_get_rows_iq2_tn")]] kernel get_rows_q_t kernel_get_rows_q2>; // // matrix-matrix multiplication @@ -7862,7 +7886,6 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm>; @@ -7884,6 +7907,7 @@ template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; template [[host_name("kernel_mul_mm_iq1_tn_f32")]] kernel mat_mm_t kernel_mul_mm>; +template [[host_name("kernel_mul_mm_iq2_tn_f32")]] kernel mat_mm_t kernel_mul_mm>; // // indirect matrix-matrix multiplication @@ -7900,7 +7924,6 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; -template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; @@ -7922,6 +7945,7 @@ template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; template [[host_name("kernel_mul_mm_id_iq1_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; +template [[host_name("kernel_mul_mm_id_iq2_tn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id>; // // matrix-vector multiplication