diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3f129ce0..8980285f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -933,7 +933,8 @@ extern "C" { GGML_API struct ggml_tensor * ggml_multi_add( struct ggml_context * ctx, - struct ggml_tensor ** a); + struct ggml_tensor * a, + int n_experts); // dst = a // view(dst, nb1, nb2, nb3, offset) += b diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 72043c2e..8ffddd6d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -62,9 +62,13 @@ static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_ int i0 = i % ne0; float * result = (float *)(dst + i1*nb1); const float * s = (const float *)(src0 + i1*nb01) + i0; - float sum = 0; - for (int j = 0; j < nused; ++j) sum += s[j*ne0]; - result[i0] = sum; + if (nused == 1) { + result[i0] = s[0]; + } else { + float sum = s[0] + s[ne0]; + for (int j = 2; j < nused; ++j) sum += s[j*ne0]; + result[i0] = sum; + } } static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) { @@ -243,29 +247,9 @@ void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1); GGML_ASSERT(dst->nb[0] == sizeof(float)); - int nused = 0; - for (int i = 0; i < GGML_MAX_SRC; ++i) { - ggml_tensor * src = dst->src[i]; - if (src) { - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_are_same_shape(src, dst)); - GGML_ASSERT(src->ne[2] == 1 && src->ne[3] == 1); - GGML_ASSERT(src->nb[0] == sizeof(float)); - ++nused; - } else { - break; - } - } - GGML_ASSERT(nused >= 2); + int nused = dst->op_params[0]; + GGML_ASSERT(nused >= 1); const char * src0 = (const char *)dst->src[0]->data; - const int64_t nb01 = dst->src[0]->ne[0]*sizeof(float); - for (int i = 1; i < nused; ++i) { - GGML_ASSERT(dst->src[i]->nb[1] == dst->src[0]->nb[1]); - const char * src = (const char *)dst->src[i]->data; - GGML_ASSERT(src == src0 + i*nb01); - GGML_ASSERT(dst->src[i]->nb[1] == dst->src[0]->nb[1]); - } - //printf("%s: nused = %d\n", __func__, nused); cudaStream_t stream = ctx.stream(); multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5e088ca0..39218ff4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5112,41 +5112,21 @@ struct ggml_tensor * ggml_add_inplace( struct ggml_tensor * ggml_multi_add( struct ggml_context * ctx, - struct ggml_tensor ** a) { + struct ggml_tensor * a, + int n_experts) { bool is_node = false; - struct ggml_tensor * a_used[GGML_MAX_SRC]; - int n_used = 0; - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (a[i]) { - a_used[n_used++] = a[i]; - } - } - - if (n_used < 2) { + if (n_experts < 1) { GGML_ABORT("fatal error"); } - if (n_used == 2) { - return ggml_add(ctx, a_used[0], a_used[1]); - } - for (int i = 1; i < n_used; ++i) { - if (!ggml_are_same_shape(a_used[i], a_used[0])) { - GGML_ABORT("fayal error"); - } - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a_used[0]); + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); result->op = GGML_OP_MULTI_ADD; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - for (int i = 0; i < n_used; ++i) { - result->src[i] = a_used[i]; - } - for (int i = n_used; i < GGML_MAX_SRC; ++i) { - result->src[i] = NULL; - } + result->src[0] = a; + result->op_params[0] = n_experts; return result; } @@ -10474,13 +10454,15 @@ static void ggml_compute_forward_multi_add_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + struct ggml_tensor * src = dst->src[0]; + GGML_ASSERT(dst->nb[0] == sizeof(float)); - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (dst->src[i]) { - GGML_ASSERT(ggml_are_same_shape(dst->src[i], dst)); - GGML_ASSERT(dst->src[i]->nb[0] == sizeof(float)); - } - } + GGML_ASSERT(src->nb[0] == sizeof(float)); + GGML_ASSERT(ggml_are_same_shape(src, dst)); + GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1); + + const int n_add = dst->op_params[0]; + GGML_ASSERT(n_add > 0); const int ith = params->ith; const int nth = params->nth; @@ -10495,22 +10477,14 @@ static void ggml_compute_forward_multi_add_f32( const int ir1 = MIN(ir0 + dr, nr); int64_t ne0 = dst->ne[0]; - int64_t ne1 = dst->ne[1]; - int64_t ne2 = dst->ne[2]; - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int i1 = ir0; i1 < ir1; ++i1) { - float * dst_ptr = (float *) ((char *) dst->data + i3*dst->nb[3] + i2*dst->nb[2] + i1*dst->nb[1] ); + float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1] ); + const float * data = (const float *) ((const char *)src->data + i1*src->nb[1]); memset(dst_ptr, 0, ne0*sizeof(float)); - for (int i = 0; i < GGML_MAX_SRC; ++i) { - struct ggml_tensor * src = dst->src[i]; - if (!src) continue; - const float * data = (const float *) ((const char *) src->data + i3*src->nb[3] + i2*src->nb[2] + i1*src->nb[1]); - ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data); + for (int j = 0; j < n_add; ++j) { + ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0); } } } diff --git a/src/llama.cpp b/src/llama.cpp index 05214d90..2b9a1b1a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8358,44 +8358,33 @@ static struct ggml_tensor * llm_build_moe_ffn( return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1])); } - if (n_expert_used <= GGML_MAX_SRC) { - ggml_tensor * src[GGML_MAX_SRC]; - for (int i = 0; i < n_expert_used; ++i) { - src[i] = ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]); - } - for (int i = n_expert_used; i < GGML_MAX_SRC; ++i) src[i] = nullptr; - return ggml_multi_add(ctx, src); - } + return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used); - GGML_ABORT("fatal error"); + //// aggregate experts + //ggml_tensor * moe_out = nullptr; + ////ggml_tensor * first_expert = nullptr; + //for (int i = 0; i < n_expert_used; ++i) { + // ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, + // experts->nb[2], i*experts->nb[1]); - //int nloop = (n_expert_used + GGML_MAX_SRC - 1)/GGML_MAX_SRC; + // if (i == 0) { + // moe_out = cur_expert; + // //first_expert = cur_expert; + // //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert), + // // (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3], + // // (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]); + // } else { + // moe_out = ggml_add(ctx, moe_out, cur_expert); + // //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert)); + // } + //} - // aggregate experts - ggml_tensor * moe_out = nullptr; - //ggml_tensor * first_expert = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); + //if (n_expert_used == 1) { + // // avoid returning a non-contiguous tensor + // moe_out = ggml_cont(ctx, moe_out); + //} - if (i == 0) { - moe_out = cur_expert; - //first_expert = cur_expert; - //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert), - // (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3], - // (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]); - } else { - moe_out = ggml_add(ctx, moe_out, cur_expert); - //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert)); - } - } - - if (n_expert_used == 1) { - // avoid returning a non-contiguous tensor - moe_out = ggml_cont(ctx, moe_out); - } - - return moe_out; + //return moe_out; } static struct ggml_tensor * llm_build_kqv(