mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Adding SWIGLU unary op (#65)
* Adding GGML_UNARY_OP_SWIGLU This commit implements the ggml op and CPU compute forward. I see ~3-4% speedup of PP-512 for Phi-3.5-mini. * GGML_UNARY_OP_SWIGLU: CUDA implementation I observe ~12% speedup for PP-512(Phi-3.5-mini). * GGML_UNARY_OP_SWIGLU: Metal implementation We get ~2% speedup for PP-512(Phi-3.5-mini). * GGML_UNARY_OP_SWIGLU: minor improvement on Metal * GGML_UNARY_OP_SWIGLU: cleanup --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -564,6 +564,7 @@ extern "C" {
|
|||||||
GGML_UNARY_OP_SILU,
|
GGML_UNARY_OP_SILU,
|
||||||
GGML_UNARY_OP_HARDSWISH,
|
GGML_UNARY_OP_HARDSWISH,
|
||||||
GGML_UNARY_OP_HARDSIGMOID,
|
GGML_UNARY_OP_HARDSIGMOID,
|
||||||
|
GGML_UNARY_OP_SWIGLU,
|
||||||
|
|
||||||
GGML_UNARY_OP_COUNT,
|
GGML_UNARY_OP_COUNT,
|
||||||
};
|
};
|
||||||
@@ -1127,6 +1128,10 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_swiglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
// a - x
|
// a - x
|
||||||
// b - dy
|
// b - dy
|
||||||
GGML_API struct ggml_tensor * ggml_silu_back(
|
GGML_API struct ggml_tensor * ggml_silu_back(
|
||||||
|
|||||||
@@ -2233,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
ggml_cuda_op_silu(ctx, dst);
|
ggml_cuda_op_silu(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
ggml_cuda_op_swiglu(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -2773,6 +2776,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
switch (ggml_get_unary_op(op)) {
|
switch (ggml_get_unary_op(op)) {
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
|||||||
@@ -31,6 +31,18 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
|||||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void swiglu_f32(const float * x, float * dst, const int k, const int ne0, const int64_t nb1) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int row = i/ne0;
|
||||||
|
const int idx = i%ne0;
|
||||||
|
const int j = row*nb1 + idx;
|
||||||
|
dst[i] = x[j] * x[j + ne0] / (1.0f + expf(-x[j]));
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
@@ -116,6 +128,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
|||||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void swiglu_f32_cuda(const float * x, float * dst, const int k, const int64_t ne0, const int64_t nb1, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
||||||
|
swiglu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k, ne0, nb1);
|
||||||
|
}
|
||||||
|
|
||||||
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
||||||
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||||
@@ -184,6 +201,21 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
|
||||||
|
|
||||||
|
swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream);
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *)src0->data;
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
|||||||
@@ -31,3 +31,5 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||||||
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU,
|
GGML_METAL_KERNEL_TYPE_SILU,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SWIGLU_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||||
@@ -583,6 +585,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_4, swiglu_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||||
@@ -884,6 +888,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@@ -1595,6 +1600,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(ne0 == src0->ne[0]/2);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
uint32_t n_per_row = ne0;
|
||||||
|
uint32_t stride = src0->nb[1]/sizeof(float);
|
||||||
|
|
||||||
|
if (ne0 % 4 == 0) {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
|
||||||
|
n /= 4;
|
||||||
|
n_per_row /= 4;
|
||||||
|
stride /= 4;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
|
||||||
|
[encoder setBytes:&stride length:sizeof(stride) atIndex:3];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -398,6 +398,30 @@ kernel void kernel_silu_4(
|
|||||||
dst[tpig] = x / (1.0f + exp(-x));
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_swiglu(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant uint & ne0,
|
||||||
|
constant uint & stride,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const uint row = tpig/ne0;
|
||||||
|
const uint idx = tpig%ne0;
|
||||||
|
const uint j = row*stride + idx;
|
||||||
|
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_swiglu_4(
|
||||||
|
device const float4 * src0,
|
||||||
|
device float4 * dst,
|
||||||
|
constant uint & ne0,
|
||||||
|
constant uint & stride,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const uint row = tpig/ne0;
|
||||||
|
const uint idx = tpig%ne0;
|
||||||
|
const uint j = row*stride + idx;
|
||||||
|
dst[tpig] = src0[j] * src0[j + ne0] / (1.0f + exp(-src0[j]));
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_sqr(
|
kernel void kernel_sqr(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|||||||
117
ggml/src/ggml.c
117
ggml/src/ggml.c
@@ -2867,6 +2867,30 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
|
||||||
|
int i = 0;
|
||||||
|
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||||
|
for (; i + 15 < n; i += 16) {
|
||||||
|
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
|
||||||
|
}
|
||||||
|
#elif defined(__AVX2__) && defined(__FMA__)
|
||||||
|
for (; i + 7 < n; i += 8) {
|
||||||
|
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
|
||||||
|
}
|
||||||
|
#elif defined(__SSE2__)
|
||||||
|
for (; i + 3 < n; i += 4) {
|
||||||
|
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
|
||||||
|
}
|
||||||
|
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||||
|
for (; i + 3 < n; i += 4) {
|
||||||
|
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; i < n; ++i) {
|
||||||
|
y[i] = ggml_silu_f32(x[i]) * x[i + n];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
|
static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
|
||||||
int i = 0;
|
int i = 0;
|
||||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||||
@@ -3386,9 +3410,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||||||
"SILU",
|
"SILU",
|
||||||
"HARDSWISH",
|
"HARDSWISH",
|
||||||
"HARDSIGMOID",
|
"HARDSIGMOID",
|
||||||
|
"SWIGLU",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
|
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
|
||||||
|
|
||||||
|
|
||||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||||
@@ -5694,6 +5719,26 @@ struct ggml_tensor * ggml_silu_inplace(
|
|||||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
|
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_swiglu
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_swiglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(a));
|
||||||
|
|
||||||
|
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
|
||||||
|
|
||||||
|
result->op = GGML_OP_UNARY;
|
||||||
|
result->grad = a->grad ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = a;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_silu_back
|
// ggml_silu_back
|
||||||
|
|
||||||
struct ggml_tensor * ggml_silu_back(
|
struct ggml_tensor * ggml_silu_back(
|
||||||
@@ -12243,6 +12288,67 @@ static void ggml_compute_forward_silu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_swiglu
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
|
||||||
|
GGML_ASSERT(dst->ne[0] == src0->ne[0]/2);
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = dst->ne[0];
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
ggml_vec_swiglu_f32(nc,
|
||||||
|
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||||
|
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
||||||
|
UNUSED(x);
|
||||||
|
assert(!isnan(x));
|
||||||
|
assert(!isinf(x));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_leaky_relu
|
// ggml_compute_forward_leaky_relu
|
||||||
|
|
||||||
static void ggml_compute_forward_leaky_relu_f32(
|
static void ggml_compute_forward_leaky_relu_f32(
|
||||||
@@ -17289,6 +17395,10 @@ static void ggml_compute_forward_unary(
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_silu(params, dst);
|
ggml_compute_forward_silu(params, dst);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu(params, dst);
|
||||||
|
} break;
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_hardswish(params, dst);
|
ggml_compute_forward_hardswish(params, dst);
|
||||||
@@ -19155,6 +19265,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
{
|
{
|
||||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||||
}
|
}
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||||
|
}
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
{
|
{
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
@@ -19656,6 +19770,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
|||||||
@@ -8111,16 +8111,8 @@ static struct ggml_tensor * llm_build_ffn(
|
|||||||
} break;
|
} break;
|
||||||
case LLM_FFN_SWIGLU:
|
case LLM_FFN_SWIGLU:
|
||||||
{
|
{
|
||||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
cur = ggml_swiglu(ctx, cur);
|
||||||
int64_t split_point = cur->ne[0] / 2;
|
cb(cur, "ffn_swiglu", il);
|
||||||
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
||||||
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
||||||
|
|
||||||
x0 = ggml_silu(ctx, x0);
|
|
||||||
cb(cur, "ffn_silu", il);
|
|
||||||
|
|
||||||
cur = ggml_mul(ctx, x0, x1);
|
|
||||||
cb(cur, "ffn_mul", il);
|
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user