diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4c2d8b32..161355ef 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2521,6 +2521,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor 0, src0_2->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + bool fuse_down = false; if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_backend_buffer_is_cuda(next->src[0]->buffer) && !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && @@ -2562,14 +2563,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); - return true; + fuse_down = true; } else { CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); CUDA_CHECK(cudaGetLastError()); - return false; } + CUDA_CHECK(cudaStreamSynchronize(stream)); + return fuse_down; } }