Adding SWIGLU unary op (#65)

* Adding GGML_UNARY_OP_SWIGLU

This commit implements the ggml op and CPU compute
forward. I see ~3-4% speedup of PP-512 for Phi-3.5-mini.

* GGML_UNARY_OP_SWIGLU: CUDA implementation

I observe ~12% speedup for PP-512(Phi-3.5-mini).

* GGML_UNARY_OP_SWIGLU: Metal implementation

We get ~2% speedup for PP-512(Phi-3.5-mini).

* GGML_UNARY_OP_SWIGLU: minor improvement on Metal

* GGML_UNARY_OP_SWIGLU: cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2024-09-28 13:37:25 +03:00
committed by GitHub
parent 1f61e91862
commit 737514fd81
8 changed files with 217 additions and 11 deletions

View File

@@ -564,6 +564,7 @@ extern "C" {
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_SWIGLU,
GGML_UNARY_OP_COUNT,
};
@@ -1127,6 +1128,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_silu_back(

View File

@@ -2233,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
case GGML_UNARY_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2773,6 +2776,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:

View File

@@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const int row = i/ne0;
const int idx = i%ne0;
const int j = row*nb1 + idx;
dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
}
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@@ -1595,6 +1600,33 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_SWIGLU:
{
int64_t n = ggml_nelements(dst);
GGML_ASSERT(ne0 == src0->ne[0]/2);
id<MTLComputePipelineState> pipeline = nil;
uint32_t n_per_row = ne0;
uint32_t stride = src0->nb[1]/sizeof(float);
if (ne0 % 4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
n /= 4;
n_per_row /= 4;
stride /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
[encoder setBytes:&stride length:sizeof(stride) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:

View File

@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_swiglu(
device const float * src0,
device float * dst,
constant uint & ne0,
constant uint & stride,
uint tpig[[thread_position_in_grid]]) {
const uint row = tpig/ne0;
const uint idx = tpig%ne0;
const uint j = row*stride + idx;
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
}
kernel void kernel_swiglu_4(
device const float4 * src0,
device float4 * dst,
constant uint & ne0,
constant uint & stride,
uint tpig[[thread_position_in_grid]]) {
const uint row = tpig/ne0;
const uint idx = tpig%ne0;
const uint j = row*stride + idx;
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
}
kernel void kernel_sqr(
device const float * src0,
device float * dst,

View File

@@ -2867,6 +2867,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}
static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
}
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
}
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]) * x[i + n];
}
}
static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -3386,9 +3410,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"SILU",
"HARDSWISH",
"HARDSIGMOID",
"SWIGLU",
};
static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5694,6 +5719,26 @@ struct ggml_tensor * ggml_silu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
}
// ggml_swiglu
struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(ggml_is_contiguous_1(a));
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
result->op = GGML_OP_UNARY;
result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_silu_back
struct ggml_tensor * ggml_silu_back(
@@ -12243,6 +12288,67 @@ static void ggml_compute_forward_silu(
}
}
}
// ggml_compute_forward_swiglu
static void ggml_compute_forward_swiglu_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
const int ith = params->ith;
const int nth = params->nth;
const int nc = dst->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_swiglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_swiglu(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_swiglu_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_leaky_relu
static void ggml_compute_forward_leaky_relu_f32(
@@ -17289,6 +17395,10 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_silu(params, dst);
} break;
case GGML_UNARY_OP_SWIGLU:
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_UNARY_OP_HARDSWISH:
{
ggml_compute_forward_hardswish(params, dst);
@@ -19155,6 +19265,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_UNARY_OP_SWIGLU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_UNARY_OP_SILU:
{
// necessary for llama
@@ -19656,6 +19770,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
{
n_tasks = n_threads;
} break;

View File

@@ -8111,16 +8111,8 @@ static struct ggml_tensor * llm_build_ffn(
} break;
case LLM_FFN_SWIGLU:
{
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
int64_t split_point = cur->ne[0] / 2;
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
x0 = ggml_silu(ctx, x0);
cb(cur, "ffn_silu", il);
cur = ggml_mul(ctx, x0, x1);
cb(cur, "ffn_mul", il);
cur = ggml_swiglu(ctx, cur);
cb(cur, "ffn_swiglu", il);
} break;
}