Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-04-30 11:21:56 +00:00
Faster MoE inference (#112)
* multi_sdd: WIP
* multi_sdd: CPU works
* multi_add: CUDA
* multi_add: simplify
* multi_add: Metal
* Metal: speed up mul_mat_id

  For the Granite-1B MoE model PP-512 goes from 156 t/s to 890 t/s, so nearly a 6X speedup!

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
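The commit introduces a GGML_OP_MULTI_ADD op that sums the per-expert outputs of a MoE layer in one node instead of a chain of binary ggml_add nodes, which is where the prompt-processing speedup comes from. A minimal usage sketch (hypothetical names and shapes, not the actual model build code; weighted_experts is assumed to be an f32 tensor whose row i holds the n_expert consecutive ne0-sized expert results for token i):

    #include "ggml.h"

    // sketch only: reduce n_expert expert contributions per token in one op
    static struct ggml_tensor * moe_reduce(struct ggml_context * ctx,
                                           struct ggml_tensor  * weighted_experts,
                                           int                   n_expert) {
        // previously this reduction took n_expert - 1 ggml_add nodes
        return ggml_multi_add(ctx, weighted_experts, n_expert);
    }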
@@ -494,6 +494,7 @@ extern "C" {
|
||||
GGML_OP_GROUP_NORM,
|
||||
GGML_OP_FUSED_RMS_NORM,
|
||||
GGML_OP_FUSED_MUL_UNARY,
|
||||
GGML_OP_MULTI_ADD,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
GGML_OP_MUL_MAT_ID,
|
||||
@@ -930,6 +931,11 @@ extern "C" {
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

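    // sums, for each output row, the n_experts consecutive ne0-sized segments of
    // the corresponding row of a (semantics inferred from the CPU and CUDA
    // implementations added in this commit)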
    GGML_API struct ggml_tensor * ggml_multi_add(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_experts);

    // dst = a
    // view(dst, nb1, nb2, nb3, offset) += b
    // return dst

@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_ADD:
            ggml_cuda_op_add(ctx, dst);
            break;
        case GGML_OP_MULTI_ADD:
            ggml_cuda_op_multi_add(ctx, dst);
            break;
        case GGML_OP_ACC:
            ggml_cuda_op_acc(ctx, dst);
            break;

@@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
            }
            if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
                // disable CUDA graphs for batch size > 1 for now.
                // Changes in batch size or context size can cause changes to the grid size of some kernels.
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
            }

            if (node->op == GGML_OP_CPY) {
                // store the copy op parameter which changes with each token.

@@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_ADD:
        case GGML_OP_MULTI_ADD:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:

@@ -52,6 +52,25 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
    dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}

static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
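    // one thread per destination element; i indexes the flattened ne1 x ne0 output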
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    int64_t k = ne0*ne1;
    if (i >= k) {
        return;
    }
    int i1 = i / ne0;
    int i0 = i % ne0;
    float * result = (float *)(dst + i1*nb1);
    const float * s = (const float *)(src0 + i1*nb01) + i0;
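    // the nused contributions for this row are laid out back-to-back, ne0 floats apart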
    if (nused == 1) {
        result[i0] = s[0];
    } else {
        float sum = s[0] + s[ne0];
        for (int j = 2; j < nused; ++j) sum += s[j*ne0];
        result[i0] = sum;
    }
}

static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

@@ -218,6 +237,23 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
    int64_t k = ne0 * ne1;
    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
}

void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
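    // op_params[0] carries the number of expert contributions summed per output row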
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
    GGML_ASSERT(dst->nb[0] == sizeof(float));
    int nused = dst->op_params[0];
    GGML_ASSERT(nused >= 1);
    const char * src0 = (const char *)dst->src[0]->data;
    cudaStream_t stream = ctx.stream();
    multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
}

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;

@@ -9,6 +9,7 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
#define CUDA_MULTI_ADD_BLOCK_SIZE 256

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -39,6 +39,8 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_ADD,
    GGML_METAL_KERNEL_TYPE_ADD_4,
    GGML_METAL_KERNEL_TYPE_ADD_ROW,
    GGML_METAL_KERNEL_TYPE_MULTI_ADD,
    GGML_METAL_KERNEL_TYPE_MULTI_ADD_4,
    GGML_METAL_KERNEL_TYPE_MUL,
    GGML_METAL_KERNEL_TYPE_MUL_4,
    GGML_METAL_KERNEL_TYPE_MUL_ROW,

@@ -577,6 +579,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,         add,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4,       add_4,       true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,     add_row,     true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD,   multi_add,   true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD_4, multi_add_4, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,         mul,         true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4,       mul_4,       true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,     mul_row,     true);

@@ -932,6 +936,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
        case GGML_OP_MULTI_ADD:
        case GGML_OP_ACC:
        case GGML_OP_MUL:
        case GGML_OP_DIV:

@@ -1349,6 +1354,36 @@ static enum ggml_status ggml_metal_graph_compute(
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
            case GGML_OP_MULTI_ADD:
                {
                    GGML_ASSERT(src0t == GGML_TYPE_F32);
                    GGML_ASSERT(dstt  == GGML_TYPE_F32);
                    GGML_ASSERT(ne02 == 1 && ne03 == 1);
                    GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
                    GGML_ASSERT(ggml_are_same_shape(src0, dst));

                    int n_expert = dst->op_params[0];
                    GGML_ASSERT(n_expert >= 2);

                    id<MTLComputePipelineState> pipeline = nil;
                    int64_t n = ne0*ne1;
                    if (ne0%4 == 0) {
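                        // row length divisible by 4: vectorized float4 kernel, a quarter of the threads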
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
                        n /= 4;
                    } else {
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
                    }
                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                    [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_REPEAT:
                {
                    id<MTLComputePipelineState> pipeline;

@@ -479,6 +479,44 @@ kernel void kernel_sqr(
    dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_multi_add_4(
        device const float4 * src0,
        device       float4 * dst,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant int64_t & nb01,
        constant int     & n_expert,
        uint tpig[[thread_position_in_grid]]) {
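
    // nb1/nb01 are byte strides: divide by 16 to step in float4 units
    // (the scalar kernel_multi_add below divides by 4 for float units)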
    int64_t i0 = tpig % (ne0/4);
    int64_t i1 = tpig / (ne0/4);
    device float4 * dst_ptr = dst + i1*(nb1/16) + i0;
    device const float4 * src_ptr = src0 + i1*(nb01/16) + i0;
    float4 sum = src_ptr[0] + src_ptr[ne0/4];
    for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0/4];
    dst_ptr[0] = sum;
}

kernel void kernel_multi_add(
        device const float * src0,
        device       float * dst,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant int64_t & nb1,
        constant int64_t & nb01,
        constant int     & n_expert,
        uint tpig[[thread_position_in_grid]]) {

    int64_t i0 = tpig % ne0;
    int64_t i1 = tpig / ne0;
    device float * dst_ptr = dst + i1*nb1/4 + i0;
    device const float * src_ptr = src0 + i1*nb01/4 + i0;
    float sum = src_ptr[0] + src_ptr[ne0];
    for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0];
    dst_ptr[0] = sum;
}

kernel void kernel_sum_rows(
        device const float * src0,
        device       float * dst,

@@ -8197,6 +8235,7 @@ kernel void kernel_mul_mm_id(
        threadgroup uchar * shared_memory [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 ntg3[[threads_per_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    const int32_t i02 = tgpig.z;

@@ -8204,25 +8243,87 @@ kernel void kernel_mul_mm_id(

    device const uchar * src0 = src0s + i02*nb02;

    // row indices
    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
    uint ntg = ntg3.x * ntg3.y * ntg3.z;
    uint n = nei0*nei1;

    // TODO: parallelize this loop
    int64_t _ne1 = 0;
    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
            if (id == i02) {
                //if (tiitg == 0) {
                rowids[_ne1] = ushort2(ii0, ii1);
                //}
                _ne1++;
            }
        }
    //uint npt   = (n + ntg - 1) / ntg;
    //uint first = tiitg * npt;
    //uint last  = first + npt <= n ? first + npt : n;

    //uint nhave = 0;
    //for (uint i = first; i < last; ++i) {
    //    uint ii0 = i % nei0;
    //    uint ii1 = i / nei0;
    //    int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
    //    if (id == i02) ++nhave;
    //}
    //threadgroup uint * nums = (threadgroup uint *)shared_memory;
    //nums[tiitg] = nhave;
    //threadgroup_barrier(mem_flags::mem_threadgroup);

    //uint nprev = 0;
    //for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
    //int64_t _ne1 = nprev;
    //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];

    //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
    //for (uint i = first; i < last; ++i) {
    //    uint ii0 = i % nei0;
    //    uint ii1 = i / nei0;
    //    int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
    //    if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
    //}

    //threadgroup_barrier(mem_flags::mem_threadgroup);

    //
    // The following is slightly faster than the commented out version above
    //
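    // Parallel row-id collection: each thread counts the expert-id matches in a
    // strided slice (nhave), the per-thread counts are exchanged through
    // threadgroup memory (nums), a prefix sum gives every thread its write
    // offset (nprev), and a second strided pass scatters the (ii0, ii1) pairs
    // into rowids without conflicts.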
    uint nhave = 0;
    for (uint i = tiitg; i < n; i += ntg) {
        uint ii0 = i % nei0;
        uint ii1 = i / nei0;
        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
        if (id == i02) ++nhave;
    }

    threadgroup uint * nums = (threadgroup uint *)shared_memory;
    nums[tiitg] = nhave;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint nprev = 0;
    for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
    int64_t _ne1 = nprev;
    for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];

    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
    for (uint i = tiitg; i < n; i += ntg) {
        uint ii0 = i % nei0;
        uint ii1 = i / nei0;
        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
        if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // This is the original version that is ridiculously slow.
    //// row indices
    //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);

    //// TODO: parallelize this loop
    //int64_t _ne1 = 0;
    //for (ushort ii1 = 0; ii1 < nei1; ii1++) {
    //    for (ushort ii0 = 0; ii0 < nei0; ii0++) {
    //        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
    //        if (id == i02) {
    //            //if (tiitg == 0) {
    //            rowids[_ne1] = ushort2(ii0, ii1);
    //            //}
    //            _ne1++;
    //        }
    //    }
    //}

    //threadgroup_barrier(mem_flags::mem_threadgroup);

    kernel_mul_mm_id_impl<Dequantizer>(
        src0,
        src1,

@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GROUP_NORM",
    "FUSED_RMS_NORM",
    "FUSED_MUL_UNARY",
    "MULTI_ADD",

    "MUL_MAT",
    "MUL_MAT_ID",

@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",

@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "group_norm(x)",
    "fused_rms_norm(x)",
    "fused_mul_unary(x)",
    "x1+x2+x3+...",

    "X*Y",
    "X[i]*Y",

@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5106,6 +5108,29 @@ struct ggml_tensor * ggml_add_inplace(
    return ggml_add_impl(ctx, a, b, true);
}

// ggml_multi_add

struct ggml_tensor * ggml_multi_add(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_experts) {

    bool is_node = false;

    if (n_experts < 1) {
        GGML_ABORT("fatal error");
    }

    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

    result->op     = GGML_OP_MULTI_ADD;
    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;
    result->op_params[0] = n_experts;
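    // the CPU, CUDA and Metal backends read the expert count back from op_params[0]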

    return result;
}

// ggml_add_cast

static struct ggml_tensor * ggml_add_cast_impl(

@@ -10425,6 +10450,59 @@ static void ggml_compute_forward_add(
    }
}

static void ggml_compute_forward_multi_add_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    struct ggml_tensor * src = dst->src[0];

    GGML_ASSERT(dst->nb[0] == sizeof(float));
    GGML_ASSERT(src->nb[0] == sizeof(float));
    GGML_ASSERT(ggml_are_same_shape(src, dst));
    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);

    const int n_add = dst->op_params[0];
    GGML_ASSERT(n_add > 0);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(dst);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    int64_t ne0 = dst->ne[0];

    for (int i1 = ir0; i1 < ir1; ++i1) {
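
        // zero the destination row, then accumulate the n_add consecutive
        // ne0-sized segments of the corresponding source row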
        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1]);
        const float * data = (const float *) ((const char *)src->data + i1*src->nb[1]);
        memset(dst_ptr, 0, ne0*sizeof(float));
        for (int j = 0; j < n_add; ++j) {
            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
        }
    }
}

static void ggml_compute_forward_multi_add(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    switch (dst->type) {
        case GGML_TYPE_F32: {
            ggml_compute_forward_multi_add_f32(params, dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
        }
    }
}

// ggml_compute_forward_add1

static void ggml_compute_forward_add1_f32(

@@ -18202,6 +18280,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_add1(params, tensor);
            } break;
        case GGML_OP_MULTI_ADD:
            {
                ggml_compute_forward_multi_add(params, tensor);
            } break;
        case GGML_OP_ACC:
            {
                ggml_compute_forward_acc(params, tensor);

@@ -18947,6 +19029,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        case GGML_OP_MULTI_ADD:
            {
                GGML_ABORT("fatal error"); // TODO: implement
            }
        case GGML_OP_CONCAT:
            {
                GGML_ABORT("fatal error"); // TODO: implement

@@ -19996,6 +20082,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
        case GGML_OP_MULTI_ADD:
            {
                n_tasks = n_threads;
            } break;