diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ef493104..d6350f6e 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1396,6 +1396,16 @@ extern "C" {
             struct ggml_tensor  * ids,
             enum ggml_unary_op    op);
 
+    GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a_up,
+            struct ggml_tensor  * a_gate,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids,
+            struct ggml_tensor  * a_up_b,
+            struct ggml_tensor  * a_gate_b,
+            enum ggml_unary_op    op);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index d2ea68f0..d3868f2e 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2221,6 +2221,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
     }
 }
 
+//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
+//
+//    for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) {
+//        dst[j] = src1[j] + src2[j % n_per_row];
+//    }
+//}
+
+static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
+
+    uint32_t row = blockIdx.x;
+    const float * src1_row = src1 + row*n_per_row;
+    float * dst_row = dst + row*n_per_row;
+
+    for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) {
+        dst_row[j] = src1_row[j] + src2[j];
+    }
+}
+
 static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, const ggml_tensor * ids,
         std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
         ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
@@ -2271,7 +2289,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
     return is_ser;
 }
 
-static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids  = dst->src[2];
@@ -2320,7 +2338,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 0, src0->ne[1], 1, src1_padded_col_size, stream);
         CUDA_CHECK(cudaGetLastError());
 
-        return;
+        if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
+            ggml_are_same_shape(src0, next->src[0]) &&
+            ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
+            ggml_backend_buffer_is_cuda(next->buffer) &&
+            !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
+            ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
+            ggml_backend_cuda_buffer_context * next_dst_ctx  = (ggml_backend_cuda_buffer_context *) next->buffer->context;
+            if (next_src0_ctx->device == device_id &&
+                next_dst_ctx->device == device_id) {
+                local_dst.data = next->data;
+                ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
+                        (const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
+                        0, src0->ne[1], 1, src1_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
+                return true;
+            }
+        }
+
+        return false;
     }
 }
@@ -2443,6 +2479,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             }
         }
     }
+    return false;
 }
 
 static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
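Changing `ggml_cuda_mul_mat_id()` from `void` to `bool` lets the backend fuse two back-to-back MUL_MAT_ID nodes that share the same activations and expert ids (the up and gate projections of a MoE FFN): the quantized copy of `src1` is reused for the second matrix multiplication, and the caller is told to skip the next graph node. A minimal sketch of how an evaluation loop consumes such a flag — the loop itself is illustrative (the real dispatch lives in ggml_cuda_compute_forward and the backend's graph-compute loop), and `eval_ctx` is an assumed name:

    // illustrative only; assumes "ggml.h" types (ggml_cgraph, ggml_tensor)
    for (int i = 0; i < cgraph->n_nodes; ++i) {
        ggml_tensor * node = cgraph->nodes[i];
        ggml_tensor * next = i + 1 < cgraph->n_nodes ? cgraph->nodes[i + 1] : nullptr;
        bool skip_next = false;
        if (node->op == GGML_OP_MUL_MAT_ID) {
            // returns true if the fused path also produced `next`
            skip_next = ggml_cuda_mul_mat_id(eval_ctx, node, next);
        }
        if (skip_next) {
            ++i; // `next` was already computed; do not evaluate it again
        }
    }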
@@ -2471,6 +2508,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             src0_2_ctx->device == device_id && src1_ctx->device == device_id && dst_ctx->device == device_id) {
+            //printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
+            //        src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
             // Fast TG path
             const int64_t n_ids = ids->ne[0];
             auto stream = ctx.stream(device_id, 0);
@@ -2506,12 +2545,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
                     0, src0_1->ne[1], 1, src1_padded_col_size, stream);
             CUDA_CHECK(cudaGetLastError());
 
+            if (dst->src[4]) {
+                ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
+                        (const int32_t *)ids->data, (float *)local_dst.data,
+                        local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
+                        local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
+            }
+
             local_dst.data = dst_gate_contiguous.get();
             ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
                     (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
                     0, src0_2->ne[1], 1, src1_padded_col_size, stream);
             CUDA_CHECK(cudaGetLastError());
 
+            if (dst->src[5]) {
+                ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
+                        (const int32_t *)ids->data, (float *)local_dst.data,
+                        local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
+                        local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
+            }
+
             if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
                 ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
                 !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
@@ -2519,8 +2572,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
                 ggml_backend_buffer_is_cuda(next->buffer) &&
                 ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
 
-                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
-                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+                auto unary_op = (ggml_unary_op)dst->op_params[0];
+                if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+                    ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                            (float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
+                } else {
+                    ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
+                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                            (float *)dst_gate_contiguous.get());
+                }
                 CUDA_CHECK(cudaGetLastError());
 
                 const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
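`ggml_swiglu_oai_cuda_f32()` is invoked with the gate result as `x` and the up result as `g`. For reference, a scalar sketch of the activation it is expected to compute, assuming the gpt-oss SWIGLU variant with `alpha = 1.702` and `limit = 7.0` (the clamping and the `+ 1` on the linear branch are what distinguish it from plain SwiGLU); the actual kernel is `swiglu_oai_cuda()` in unary.cu:

    #include <math.h>
    // assumed per-element semantics, contiguous k elements
    static void swiglu_oai_ref(const float * x,  // gate projection
                               const float * g,  // up projection
                               float * dst, int64_t k, float alpha, float limit) {
        for (int64_t i = 0; i < k; ++i) {
            const float xi = fminf(x[i], limit);                // gate clamped from above
            const float gi = fmaxf(fminf(g[i], limit), -limit); // up clamped to [-limit, limit]
            dst[i] = (xi / (1.0f + expf(-alpha*xi))) * (gi + 1.0f);
        }
    }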
@@ -2556,8 +2616,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
                 return true;
             } else {
                 CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
-                ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
-                        (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+                auto unary_op = (ggml_unary_op)dst->op_params[0];
+                if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+                    ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                            (float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
+                } else {
+                    ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
+                            (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+                }
                 CUDA_CHECK(cudaGetLastError());
                 return false;
             }
@@ -2625,7 +2691,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         final_src.nb[3] = final_src.nb[2];
     }
 
-    if (ne12 == 1) {
+    if (false && ne12 == 1) {
         ggml_cuda_pool_alloc<float> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
         ggml_cuda_pool_alloc<float> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
         if (fuse_down) {
@@ -2762,6 +2828,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         }
         CUDA_CHECK(cudaGetLastError());
 
+        if (dst->src[4]) {
+            dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
+            dim3 grid_dims(num_src1_rows);
+            k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
+                    (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
+            CUDA_CHECK(cudaGetLastError());
+        }
+
         dst_row.data = dst_gate_contiguous.get();
         if (use_quantized_src1) {
             ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
@@ -2771,8 +2845,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         }
         CUDA_CHECK(cudaGetLastError());
 
+        if (dst->src[5]) {
+            dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
+            dim3 grid_dims(num_src1_rows);
+            k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
+                    (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
+            CUDA_CHECK(cudaGetLastError());
+        }
+
+        auto unary_op = (ggml_unary_op)dst->op_params[0];
+        if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+            ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                    (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
+                    1.702f, 7.0f, stream);
+        } else {
+            ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
+                    (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                    (float *)dst_gate_contiguous.get());
+        }
         CUDA_CHECK(cudaGetLastError());
 
         if (fuse_down) {
@@ -2945,7 +3035,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             } break;
         case GGML_OP_MUL_MAT_ID:
-            ggml_cuda_mul_mat_id(ctx, dst);
+            skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
            break;
        case GGML_OP_MOE_FUSED_UP_GATE:
            skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
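In this prompt-processing path all `num_src1_rows` rows handled in one iteration belong to the same expert `i02`, so the bias is a single row broadcast over all rows; that is all `k_quick_add` does (one thread block per row). A contiguous host-side sketch of the same computation, with names chosen for illustration:

    // dst and src hold n_rows rows of n_per_row floats; bias is one row
    static void quick_add_ref(int n_rows, int n_per_row,
                              const float * src, const float * bias, float * dst) {
        for (int r = 0; r < n_rows; ++r) {
            for (int j = 0; j < n_per_row; ++j) {
                dst[r*n_per_row + j] = src[r*n_per_row + j] + bias[j];
            }
        }
    }

This contrasts with `ggml_cuda_add_id()` on the token-generation path, where rows are interleaved per token and expert slot, so the bias row must be looked up through the ids tensor.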
diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu
index 8bed62ac..34e66954 100644
--- a/ggml/src/ggml-cuda/add-id.cu
+++ b/ggml/src/ggml-cuda/add-id.cu
@@ -56,3 +56,17 @@ void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         nb21
     );
 }
+
+void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne00, int64_t ne01, int64_t ne02,
+        int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, stream>>>(
+        src0, src1, src2, dst,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}
diff --git a/ggml/src/ggml-cuda/add-id.cuh b/ggml/src/ggml-cuda/add-id.cuh
index 30b1721a..175d6800 100644
--- a/ggml/src/ggml-cuda/add-id.cuh
+++ b/ggml/src/ggml-cuda/add-id.cuh
@@ -1,3 +1,8 @@
 #include "common.cuh"
 
 void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne00, int64_t ne01, int64_t ne02,
+        int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);
+
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ffcaf219..dfc31c77 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -524,7 +524,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
-    // so no other implementation works.
+    // so no other implementation works, and for head size 64 with large enough batches.
     //
-    if (new_mma_available(cc) && Q->ne[0] == 576) {
+    if (new_mma_available(cc) && (Q->ne[0] == 576 || (Q->ne[0] == 64 && Q->ne[1] >= 128))) {
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 4abd6d51..a6742b98 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -546,3 +546,7 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
 }
 
+void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
+        const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+    swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
+}
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 21f39510..9da1f8ca 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -49,3 +49,7 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
 void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
+        const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);
+
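The new `ggml_cuda_add_id()` wrapper exposes the existing `add_id_kernel` launch with explicit sizes and strides, so the fused MoE path can apply per-expert biases without a separate graph node. A contiguous CPU sketch of the indexing, assumed from GGML_OP_ADD_ID semantics (for token `i2` and expert slot `i1`, the bias row is selected by `ids[i2][i1]`; the real kernel works with byte strides instead):

    static void add_id_ref(const float * src0, const float * bias, const int32_t * ids, float * dst,
                           int64_t ne00,   // row size (columns)
                           int64_t ne01,   // experts used per token
                           int64_t ne02) { // tokens
        for (int64_t i2 = 0; i2 < ne02; ++i2) {
            for (int64_t i1 = 0; i1 < ne01; ++i1) {
                const int32_t expert = ids[i2*ne01 + i1];
                const float * s = src0 + (i2*ne01 + i1)*ne00;
                const float * b = bias + expert*ne00;     // per-expert bias row
                float       * d = dst  + (i2*ne01 + i1)*ne00;
                for (int64_t i0 = 0; i0 < ne00; ++i0) {
                    d[i0] = s[i0] + b[i0];
                }
            }
        }
    }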
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5084a6fe..4d1ffcf7 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7085,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
     result->src[1] = as_gate;
     result->src[2] = b;
     result->src[3] = ids;
+    result->src[4] = NULL;
+    result->src[5] = NULL;
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_moe_up_gate_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * as_up,
+        struct ggml_tensor  * as_gate,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids,
+        struct ggml_tensor  * as_up_b,
+        struct ggml_tensor  * as_gate_b,
+        enum ggml_unary_op    op) {
+
+    if (!as_up_b && !as_gate_b) {
+        return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
+    }
+
+    if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
+        struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
+        if (as_up_b) {
+            result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
+        }
+        struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
+        if (as_gate_b) {
+            result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
+        }
+        return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
+    }
+
+    GGML_ASSERT(!ggml_is_transposed(as_up));
+    GGML_ASSERT(!ggml_is_transposed(as_gate));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(as_up->ne[3] == 1);                  // as is 3d (one matrix per expert)
+    GGML_ASSERT(b->ne[3] == 1);                      // b is 3d
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[2]);             // must have an expert list per b row
+    GGML_ASSERT(as_up->ne[0] == b->ne[0]);           // can_mul_mat
+    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0);         // can broadcast
+
+    GGML_ASSERT(!as_up_b   || as_up->ne[1]   == as_up_b->ne[0]);
+    GGML_ASSERT(!as_gate_b || as_gate->ne[1] == as_gate_b->ne[0]);
+
+    bool is_node = false;
+
+    const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_MOE_FUSED_UP_GATE;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = as_up;
+    result->src[1] = as_gate;
+    result->src[2] = b;
+    result->src[3] = ids;
+    result->src[4] = as_up_b;
+    result->src[5] = as_gate_b;
 
     ggml_set_op_params_i32(result, 0, (int32_t) op);
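A usage sketch of the new API from graph-building code; tensor names follow the llama.cpp variables used below, with the usual MoE shapes:

    // up_exps/gate_exps:     [n_embd, n_ff, n_expert] expert weights (same type and shape),
    // cur:                   [n_embd, n_tokens] activations,
    // selected_experts:      [n_expert_used, n_tokens], GGML_TYPE_I32,
    // up_exps_b/gate_exps_b: [n_ff, n_expert] biases; either may be NULL
    struct ggml_tensor * par = ggml_moe_up_gate_ext(ctx,
            up_exps, gate_exps, cur, selected_experts,
            up_exps_b, gate_exps_b,
            GGML_UNARY_OP_SWIGLU_OAI);
    // par is [n_ff, n_expert_used, n_tokens], F32

With both biases NULL this degenerates to `ggml_moe_up_gate()`; with mismatched up/gate weight types or shapes it falls back to separate mul_mat_id + add_id + fused mul-unary nodes instead of the fused op.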
diff --git a/src/llama.cpp b/src/llama.cpp
index 19e1e914..03955003 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10512,7 +10512,7 @@ static ggml_tensor * llm_build_moe_ffn(
         float w_scale,
         llm_expert_gating_func_type gating_op,
         const llm_build_cb & cb,
-        int il) {
+        int il, struct ggml_cgraph * graph = nullptr) {
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
     bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -10606,23 +10606,38 @@ llm_expert_gating_func_type gating_op,
     //
-    // For now we don't modify the fused up/gate op to include biases.
-    // Hence, if we have biases, we cannot use fmoe.
+    // The fused up/gate op now supports biases via ggml_moe_up_gate_ext(),
+    // so fmoe can also be used when biases are present (e.g., SWIGLU_OAI MoE).
     //
-    bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
+    //bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
+    bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE;
     ggml_tensor * par;
     if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
-        par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        if (up_exps_b || gate_exps_b) {
+            par = ggml_moe_up_gate_ext(ctx, up_exps, gate_exps, cur, selected_experts, up_exps_b, gate_exps_b,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                    type_op == LLM_FFN_GELU ? GGML_UNARY_OP_GELU : GGML_UNARY_OP_SWIGLU_OAI);
+        } else {
+            GGML_ASSERT(type_op != LLM_FFN_SWIGLU_OAI_MOE);
+            par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        }
     } else {
         ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
         cb(up, "ffn_moe_up", il);
 
+        ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(gate, "ffn_moe_gate", il);
+
+        if (graph) {
+            // So we can potentially fuse the up and gate mul_mat_id
+            ggml_build_forward_expand(graph, up);
+            ggml_build_forward_expand(graph, gate);
+        }
+
         if (up_exps_b) {
             up = ggml_add_id(ctx, up, up_exps_b, selected_experts);
             cb(up, "ffn_moe_up_biased", il);
         }
 
-        ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-        cb(gate, "ffn_moe_gate", il);
-
         if (gate_exps_b) {
             gate = ggml_add_id(ctx, gate, gate_exps_b, selected_experts);
             cb(gate, "ffn_moe_gate_biased", il);
@@ -10683,7 +10698,7 @@ static ggml_tensor * llm_build_moe_ffn(
         float w_scale,
         llm_expert_gating_func_type gating_op,
         const llm_build_cb & cb,
-        int il) {
+        int il, struct ggml_cgraph * graph = nullptr) {
     return llm_build_moe_ffn(ctx, lctx, cur,
             gate_inp, nullptr,
             up_exps, nullptr,
@@ -10692,7 +10707,7 @@ llm_expert_gating_func_type gating_op,
             exp_probs_b, n_expert, n_expert_used,
             type_op, norm_w, scale_w, w_scale,
-            gating_op, cb, il);
+            gating_op, cb, il, graph);
 }
 
 static struct ggml_tensor * llm_build_kqv(
@@ -18242,7 +18257,7 @@ struct llm_build_context {
                         LLM_FFN_SWIGLU_OAI_MOE, false, false, 0.0,
                         LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
-                        cb, il);
+                        cb, il, gf);
                 cb(cur, "ffn_moe_out", il);
 
                 cur = ggml_add(ctx0, cur, ffn_inp);