diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index d5e8d6ae..ca08b0f5 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2286,15 +2286,15 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ2_KS) { - const int mem_size = 64*sizeof(float); + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -2693,13 +2693,18 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t _ne1 = 1; const int tgz = dst_rows; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ2_KS || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 4af428ef..fe197309 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6173,7 +6173,7 @@ void kernel_mul_mv_iq2_k_f32_impl( device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; - float sumf[N_DST]={0.f}, all_sum; + float sumf[N_DST]={0.f}; const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 @@ -6183,9 +6183,14 @@ void kernel_mul_mv_iq2_k_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + threadgroup float * all_values = (threadgroup float *)shared_values + 8*sgitg; + { + if (tiisg < 8) all_values[tiisg] = kvalues_iq2k_f[tiisg]; + simdgroup_barrier(mem_flags::mem_none); + } + uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; - uint16_t shift[4]; for (int ib = ix; ib < nb; ib += 4) { @@ -6204,31 +6209,29 @@ void kernel_mul_mv_iq2_k_f32_impl( const uint32_t scales32 = (sc[iq] >> 4*is) & 0x0f0f0f0f; thread const int8_t * s8 = (thread const int8_t *)&scales32; - uint16_t extra = xb.extra >> (8*iq + is); - - shift[0] = (extra << 2) & 4; - shift[1] = (extra << 0) & 4; - shift[2] = (extra >> 2) & 4; - shift[3] = (extra >> 4) & 4; + uint16_t extra = (xb.extra >> (8*iq + is)) << 2; float4 acc = {0.f}; for (int l = 0; l < 4; ++l) { - constant float * values = kvalues_iq2k_f + shift[l]; + threadgroup const float * values = all_values + (extra & 4); aux32[0] = (q32[0] >> 2*l) & 0x03030303; aux32[1] = (q32[1] >> 2*l) & 0x03030303; for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; + extra >>= 2; } - sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 8) + acc[1] * (s8[1] - 8) + acc[2] * (s8[2] - 8) + acc[3] * (s8[3] - 8)); + + sumf[row] += (float)xb.d * (acc[0] * s8[0] + acc[1] * s8[1] + acc[2] * s8[2] + acc[3] * s8[3] - 8.f*(acc[0] + acc[1] + acc[2] + acc[3])); } y4 += 4 * QK_K; } - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + for (int row = 0; row < N_DST; row += 2) { + float2 tmp{sumf[row], sumf[row+1]}; + tmp = simd_sum(tmp); + if (tiisg < 2) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg]; } } } @@ -6259,7 +6262,7 @@ kernel void kernel_mul_mv_iq2_k_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_ks_f32_impl(