softcap: Metal and NEON

Yields about a 1% overall speedup.
This commit is contained in:
Iwan Kawrakow
2024-08-02 07:41:52 +02:00
parent e49ce89901
commit 0e2d76bb7c
3 changed files with 56 additions and 4 deletions

View File

@@ -51,6 +51,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
GGML_METAL_KERNEL_TYPE_SCALE,
GGML_METAL_KERNEL_TYPE_SCALE_4,
GGML_METAL_KERNEL_TYPE_SOFTCAP,
GGML_METAL_KERNEL_TYPE_SOFTCAP_4,
GGML_METAL_KERNEL_TYPE_CLAMP,
GGML_METAL_KERNEL_TYPE_TANH,
GGML_METAL_KERNEL_TYPE_RELU,
@@ -554,6 +556,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP, softcap, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP_4, softcap_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
@@ -867,6 +871,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true;
case GGML_OP_SOFTCAP:
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
@@ -1411,6 +1417,32 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFTCAP:
{
// softcap(x) = s_after * tanh(s_before * x); the two scale factors are
// packed as consecutive floats in op_params and copied out below.
// NOTE(review): supports_op reports SOFTCAP unconditionally (the
// contiguity check there is commented out), yet this path asserts
// contiguity — a non-contiguous src0 would abort here. Confirm callers
// only produce contiguous inputs for this op.
GGML_ASSERT(ggml_is_contiguous(src0));
float scales[2];
memcpy(scales, dst->op_params, sizeof(scales));
int64_t n = ggml_nelements(dst);
id<MTLComputePipelineState> pipeline = nil;
// Prefer the float4 kernel when the element count divides evenly by 4;
// n becomes the number of float4 elements in that case.
if (n % 4 == 0) {
n /= 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
// Pass the two scales as separate small constants (buffer indices 2 and 3),
// matching the kernel's two `constant float &` parameters.
[encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
[encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
// One thread per (vector) element — mirrors the SCALE dispatch above.
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_CLAMP:

View File

@@ -289,6 +289,24 @@ kernel void kernel_scale_4(
dst[tpig] = src0[tpig] * scale;
}
kernel void kernel_softcap(
device const float * src0,
device float * dst,
constant float & s_before,
constant float & s_after,
uint tpig[[thread_position_in_grid]]) {
    // Soft-capping: squash the pre-scaled input through tanh, then rescale.
    // precise:: keeps tanh at full accuracy even under fast-math compilation.
    const float scaled = src0[tpig] * s_before;
    dst[tpig] = precise::tanh(scaled) * s_after;
}
kernel void kernel_softcap_4(
device const float4 * src0,
device float4 * dst,
constant float & s_before,
constant float & s_after,
uint tpig[[thread_position_in_grid]]) {
    // Vectorized soft-capping: one float4 per thread, tanh applied
    // component-wise. precise:: keeps full accuracy under fast-math.
    const float4 scaled = src0[tpig] * s_before;
    dst[tpig] = precise::tanh(scaled) * s_after;
}
kernel void kernel_clamp(
device const float * src0,
device float * dst,

View File

@@ -2558,12 +2558,12 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
inline static float32x4_t ggml_v_softcap(float32x4_t x, float s_before, float s_after) {
inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f*s_before));
const float32x4_t two_x = vmulq_f32(x, s_before);
const float32x4_t exp_two_x = ggml_v_expf(two_x);
const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
return vmulq_f32(th, vdupq_n_f32(s_after));
return vmulq_f32(th, s_after);
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -2834,8 +2834,10 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
_mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
float32x4_t vs_before = vdupq_n_f32(2.f*s_before);
float32x4_t vs_after = vdupq_n_f32(s_after);
for (; i + 3 < n; i += 4) {
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), s_before, s_after));
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
}
#endif
for (; i < n; ++i) {