mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
softcap: Metal and NEON
About 1% speedup.
This commit is contained in:
@@ -51,6 +51,8 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
||||||
GGML_METAL_KERNEL_TYPE_SCALE,
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
||||||
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SOFTCAP,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SOFTCAP_4,
|
||||||
GGML_METAL_KERNEL_TYPE_CLAMP,
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
||||||
GGML_METAL_KERNEL_TYPE_TANH,
|
GGML_METAL_KERNEL_TYPE_TANH,
|
||||||
GGML_METAL_KERNEL_TYPE_RELU,
|
GGML_METAL_KERNEL_TYPE_RELU,
|
||||||
@@ -554,6 +556,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP, softcap, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFTCAP_4, softcap_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
||||||
@@ -867,6 +871,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_SOFTCAP:
|
||||||
|
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
@@ -1411,6 +1417,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_OP_SOFTCAP:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
float scales[2];
|
||||||
|
memcpy(scales, dst->op_params, sizeof(scales));
|
||||||
|
|
||||||
|
int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
if (n % 4 == 0) {
|
||||||
|
n /= 4;
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
|
||||||
|
[encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CLAMP:
|
case GGML_OP_CLAMP:
|
||||||
|
|||||||
@@ -289,6 +289,24 @@ kernel void kernel_scale_4(
|
|||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_softcap(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant float & s_before,
|
||||||
|
constant float & s_after,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_softcap_4(
|
||||||
|
device const float4 * src0,
|
||||||
|
device float4 * dst,
|
||||||
|
constant float & s_before,
|
||||||
|
constant float & s_after,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = s_after * precise::tanh(src0[tpig] * s_before);
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_clamp(
|
kernel void kernel_clamp(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|||||||
@@ -2558,12 +2558,12 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
|
|||||||
return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static float32x4_t ggml_v_softcap(float32x4_t x, float s_before, float s_after) {
|
inline static float32x4_t ggml_v_softcap(float32x4_t x, float32x4_t s_before, float32x4_t s_after) {
|
||||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||||
const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f*s_before));
|
const float32x4_t two_x = vmulq_f32(x, s_before);
|
||||||
const float32x4_t exp_two_x = ggml_v_expf(two_x);
|
const float32x4_t exp_two_x = ggml_v_expf(two_x);
|
||||||
const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
|
||||||
return vmulq_f32(th, vdupq_n_f32(s_after));
|
return vmulq_f32(th, s_after);
|
||||||
}
|
}
|
||||||
|
|
||||||
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
|
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||||
@@ -2834,8 +2834,10 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
|
|||||||
_mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
|
_mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
|
||||||
}
|
}
|
||||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||||
|
float32x4_t vs_before = vdupq_n_f32(2.f*s_before);
|
||||||
|
float32x4_t vs_after = vdupq_n_f32(s_after);
|
||||||
for (; i + 3 < n; i += 4) {
|
for (; i + 3 < n; i += 4) {
|
||||||
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), s_before, s_after));
|
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
for (; i < n; ++i) {
|
for (; i < n; ++i) {
|
||||||
|
|||||||
Reference in New Issue
Block a user