diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu index 3ad4fb1f..dda70252 100644 --- a/ggml/src/ggml-cuda/reduce.cu +++ b/ggml/src/ggml-cuda/reduce.cu @@ -54,6 +54,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_ GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(nhave >=2 && nhave <= nreduce); + if (dst->op_params[3] == 1) { + // The dst tensor is just a container for the sources and the reduce op is turned off + return; + } auto & info = ggml_cuda_info(); #ifdef GGML_USE_NCCL diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index d9f12ec3..c515d6e9 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -701,6 +701,13 @@ ggml_tensor * llm_build_context::llm_build_ffn( if (cur->ne[1] >= 32) { cur = ggml_cast(ctx, cur, GGML_TYPE_F16); } + if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) { + // When the reduce op is turned off via op_params[3] == 1, we need to add each src + // rtaher than add the reduced add_extra result to the ffn reduced ffn result. + GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different + cur = ggml_add(ctx, cur, add_extra->src[id]); + cb(cur, "ffn_with_extra", il_cb); + } if (graph) { ggml_build_forward_expand(graph, cur); } @@ -712,7 +719,7 @@ ggml_tensor * llm_build_context::llm_build_ffn( ffn[id_last] = ggml_add(ctx, ffn[id_last], input); cb(ffn[id_last], "ffn_with_inp", il); } - if (add_extra) { + if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) { ffn[id_last] = ggml_add(ctx, ffn[id_last], add_extra); cb(ffn[id_last], "ffn_with_inp", il); } @@ -7288,6 +7295,8 @@ ggml_cgraph * llm_build_context::build_cohere2() { inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } + attn_out->op_params[3] = 1; // i.e., turn off the reduce operation as it is not required + // feed-forward network cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,