diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 292f9ac7..1e940c5b 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -51,6 +51,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_SOFTCAP,
+    GGML_METAL_KERNEL_TYPE_SOFTCAP_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
@@ -554,6 +556,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP, softcap, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP_4, softcap_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
@@ -867,6 +871,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
             return true;
+        case GGML_OP_SOFTCAP:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op); // kernel indexes flat memory; graph_compute asserts this, so reject here to allow CPU fallback
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
@@ -1411,6 +1417,32 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                     [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SOFTCAP:
+                {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    float scales[2];
+                    memcpy(scales, dst->op_params, sizeof(scales)); // scales[0] = pre-tanh scale, scales[1] = post-tanh scale
+
+                    int64_t n = ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        n /= 4; // element count divisible by 4: use the float4 kernel for wider loads/stores
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
+                    [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
+
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_CLAMP:
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 904639a5..2a0e84a6 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -289,6 +289,24 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
+kernel void kernel_softcap(
+        device const float * src0,
+        device       float * dst,
+        constant     float & s_before,
+        constant     float & s_after,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before); // soft-cap: y = s_after * tanh(s_before * x)
+}
+
+kernel void kernel_softcap_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & s_before,
+        constant     float  & s_after,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before); // tanh applies component-wise to float4
+}
+
 kernel void kernel_clamp(
         device const float * src0,
         device       float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 94c3eb3a..41a8e712 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2558,12 +2558,12 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
     return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
 }
 
-inline static float32x4_t ggml_v_softcap(float32x4_t x, float s_before, float s_after) {
+inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) {
     const float32x4_t one = vdupq_n_f32(1.0f);
-    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f*s_before));
+    const float32x4_t two_x = vmulq_f32(x, s_before); // caller pre-folds the factor 2 into s_before (tanh(t) = (e^{2t}-1)/(e^{2t}+1))
     const float32x4_t exp_two_x = ggml_v_expf(two_x);
     const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
-    return vmulq_f32(th, vdupq_n_f32(s_after));
+    return vmulq_f32(th, s_after);
 }
 
 #elif defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2834,8 +2834,10 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
         _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
     }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
+    float32x4_t vs_before = vdupq_n_f32(2.f*s_before); // hoist broadcasts out of the loop; fold the 2x needed by the exp-based tanh
+    float32x4_t vs_after = vdupq_n_f32(s_after);
     for (; i + 3 < n; i += 4) {
-        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), s_before, s_after));
+        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
     }
 #endif
     for (; i < n; ++i) {