diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 83bd76f9..aa7b043e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -71,6 +71,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, + GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -580,6 +582,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4, soft_cap_max_f16_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32, soft_cap_max_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4, soft_cap_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32, soft_cap_max_u32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4, soft_cap_max_u32_4, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); @@ -1694,19 +1698,22 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_CAP_MAX: { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); int nth = 32; // SIMD width id pipeline = nil; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); if (ne00%4 == 0) { while (nth < ne00/4 && 
nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32_4].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; @@ -1715,7 +1722,9 @@ static enum ggml_status ggml_metal_graph_compute( while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { nth *= 2; } - if (use_f16) { + if (use_u32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_U32].pipeline; + } else if (use_f16) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; } else { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f9c88a37..58b3e6bc 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -661,6 +661,101 @@ kernel void kernel_soft_max_4( } } +kernel void kernel_soft_cap_max_u32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, + constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float * ) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + 
device float * pdst = (device float * ) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float lmax = -INFINITY; + + const float tot_scale = scale * s_after; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + float val = pmask[i00 >> 5] & (1u << (i00 & 31)) ? precise::tanh(s_before*psrc0[i00])*tot_scale : -INFINITY; + lmax = MAX(lmax, val); + pdst[i00] = val; + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp(pdst[i00] - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max( device const char * src0, @@ -767,6 +862,116 @@ kernel void kernel_soft_cap_max( } } +kernel void kernel_soft_cap_max_u32_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant float & s_before, 
+ constant float & s_after, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const uint32_t * pmask = (device const uint32_t *) src1 + i01*ne00/32; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + const float tot_scale = scale * s_after; + + // parallel max + float4 lmax4 = -INFINITY; + float4 vinf = lmax4; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + float4 val = precise::tanh(s_before*psrc4[i00])*tot_scale; + int idx = 4*i00; + uint8_t m = pmask[idx >> 5] >> (idx & 31); + bool4 m4 = { m & 1 ? true : false, m & 2 ? true : false, m & 4 ? true : false, m & 8 ? true : false }; + //bool4 m4 = ((pmask[idx >> 5] >> (idx & 31)) & 0xf) * 0x01010101; + val = select(vinf, val, m4); + //uint32_t m = pmask[idx >> 5] >> (idx & 31); + //val[0] = m & 1 ? val[0] : -INFINITY; + //val[1] = m & 2 ? val[1] : -INFINITY; + //val[2] = m & 4 ? val[2] : -INFINITY; + //val[3] = m & 8 ? 
val[3] : -INFINITY; + lmax4 = fmax(lmax4, val); + pdst4[i00] = val; + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp(pdst4[i00] - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + template kernel void kernel_soft_cap_max_4( device const char * src0, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2c740f9d..09b6c0b4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2043,6 +2043,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +#ifdef __AVX512F__ static inline float ggml_vec_add_f32_f16(const int n, 
const ggml_half * x, float * y, float slope) { __m512 vslope = _mm512_set1_ps(slope); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2058,7 +2059,6 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float } return max; } - static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { __m512 vslope = _mm512_set1_ps(slope); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2074,7 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } - static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { GGML_ASSERT(n%16 == 0); __m512 vmax = _mm512_set1_ps(-INFINITY); @@ -2087,6 +2086,29 @@ static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, fl } return _mm512_reduce_max_ps(vmax); } +#else +// TODO +static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(slope); + return 0.f; +} +static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(slope); + return 0.f; +} +static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + return 0.f; +} +#endif static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -2903,6 +2925,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } +#ifdef __AVX512F__ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { const __mmask16 * m16 = (const __mmask16 *)mask; __m512 vinf = _mm512_set1_ps(-INFINITY); @@ -2916,6 +2939,17 @@ static float
ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +#else +static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(s_before); + GGML_UNUSED(s_after); + return 0.f; +} +#endif static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { int i = 0; @@ -13788,7 +13822,7 @@ static void ggml_compute_forward_softcap( default: { GGML_ASSERT(false); - } break; + } } } @@ -13920,7 +13954,7 @@ static void ggml_compute_forward_softcap_max( default: { GGML_ASSERT(false); - } break; + } } }