GGML_UNARY_OP_SWIGLU: CUDA implementation

I observe ~12% speedup for PP-512(Phi-3.5-mini).
This commit is contained in:
Iwan Kawrakow
2024-09-28 10:31:59 +03:00
parent c4886b219c
commit 79a57b1554
3 changed files with 38 additions and 0 deletions

View File

@@ -2233,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
case GGML_UNARY_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2773,6 +2776,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:

View File

@@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const int row = i/ne0;
const int idx = i%ne0;
const int j = row*nb1 + idx;
dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
}
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);