diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 327961b7..b47b7dd5 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2078,9 +2078,43 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
                     src0->type, stream);
             CUDA_CHECK(cudaGetLastError());
-            ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
-                    0, src0->ne[1], src1->ne[1], ne10_padded, stream);
-            CUDA_CHECK(cudaGetLastError());
+            // The code below handles the case when Q, K, V have a bias applied after the respective matrix multiplication.
+            // In that case the graph contains mul_mat(Q) -> mul_mat(K) -> mul_mat(V) -> add(Q) -> add(K) -> add(V)
+            if (cgraph && node_n + 5 < cgraph->n_nodes &&
+                cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
+                cgraph->nodes[node_n+2]->op == GGML_OP_MUL_MAT &&
+                ggml_is_quantized(cgraph->nodes[node_n+1]->src[0]->type) &&
+                ggml_is_quantized(cgraph->nodes[node_n+2]->src[0]->type) &&
+                cgraph->nodes[node_n+3]->op == GGML_OP_ADD &&
+                cgraph->nodes[node_n+4]->op == GGML_OP_ADD &&
+                cgraph->nodes[node_n+5]->op == GGML_OP_ADD &&
+                cgraph->nodes[node_n+0] == cgraph->nodes[node_n+3]->src[0] &&
+                cgraph->nodes[node_n+1] == cgraph->nodes[node_n+4]->src[0] &&
+                cgraph->nodes[node_n+2] == cgraph->nodes[node_n+5]->src[0]) {
+                for (int i = 0; i < 3; ++i) {
+                    auto src0_i = cgraph->nodes[node_n+i]->src[0];
+                    ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0_i, src1, cgraph->nodes[node_n+i], cgraph->nodes[node_n+i+3]->src[1],
+                            (const char *)src0_i->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+i]->data,
+                            0, src0_i->ne[1], src1->ne[1], ne10_padded, stream);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+                node_n += 5;
+            } else if (cgraph && node_n + 1 < cgraph->n_nodes &&
+                cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
+                dst == cgraph->nodes[node_n+1]->src[0] &&
+                dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
+                cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
+                ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
+                // We have a bias applied after the matrix multiplication and we can fuse it
+                ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
+                        (const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
+                        0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+                ++node_n;
+            } else {
+                ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+                        0, src0->ne[1], src1->ne[1], ne10_padded, stream);
+                CUDA_CHECK(cudaGetLastError());
+            }
         } else {
             quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded,
                     src0->type, stream);
             CUDA_CHECK(cudaGetLastError());
@@ -2101,8 +2135,21 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
         if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
         if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
         if (is_gemv) {
-            ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
-                    (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+            if (node_n + 1 < cgraph->n_nodes &&
+                cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
+                dst == cgraph->nodes[node_n+1]->src[0] &&
+                dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
+                cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
+                ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
+                // We have a bias applied after the matrix multiplication and we can fuse it
+                ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
+                        (const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
+                        0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+                ++node_n;
+            } else {
+                ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
+                        (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+            }
         } else {
             ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
                     (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 701b0f80..7ab7fc26 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -313,7 +313,25 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
 }
 
+static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, const float * y, float * z) {
+    int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= nelem) {
+        return;
+    }
+    z[i] = x[i] + y[i % ne0];
+}
+
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
+        dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
+        ggml_are_same_shape(dst, dst->src[0]) && ggml_is_contiguous(dst)) {
+        constexpr int kBlockSize = 256;
+        auto nelem = ggml_nelements(dst);
+        int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
+        k_fast_add<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
+        return;
+    }
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 68dec960..6734acb3 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,25 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst_direct, const int ne,
+                                          char ** cdst_indirect, int graph_cpynode_index) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    auto dst = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index] : (dst_t *)cdst_direct;
+    auto src = (const src_t *)cx;
+
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i] = __float2bfloat16(src[i]);
+    } else {
+        dst[i] = (dst_t)src[i];
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
 
@@ -163,6 +182,16 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int ne,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -404,6 +433,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
+
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
@@ -429,11 +460,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             }
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, float>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -505,6 +548,7 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         // Prioritize CUDA graph compatibility over direct memory copy optimization.
         // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
@@ -514,11 +558,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
             return nullptr;
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, float> : (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, nv_bfloat16> : (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, half> : (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 47a7fd51..f70b60ab 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -168,9 +168,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
     GGML_UNUSED(src1_ddf_i);
 }
 
-void ggml_cuda_op_mul_mat_vec_q(
+void ggml_cuda_op_mul_mat_vec_q_biased(
     ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
+    const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
@@ -180,14 +181,37 @@ void ggml_cuda_op_mul_mat_vec_q(
 
     const int64_t ne0 = dst->ne[0];
 
+    if (bias) {
+        if (bias->ne[0] != ne0) {
+            printf("Oops: bias %s is %ld x %ld x %ld x %ld, dst %s is %ld x %ld x %ld x %ld\n",
+                    bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3],
+                    dst->name, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
+        }
+        GGML_ASSERT(bias->ne[0] == ne0);
+        GGML_ASSERT(bias->type == GGML_TYPE_F32);
+        if (ggml_nrows(bias) != 1) {
+            printf("Oops: bias %s is %ld x %ld x %ld x %ld\n", bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3]);
+        }
+        GGML_ASSERT(ggml_nrows(bias) == 1);
+    }
+
     ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, 1, 0, 0, 0, 0, 0,
-            src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, nullptr, nullptr,
+            src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, bias ? bias->data : nullptr, nullptr,
             row_low, row_high, src1_ncols, src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
 
     GGML_UNUSED(src1_ddf_i);
 }
 
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0, src1, dst, nullptr, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols,
+            src1_padded_row_size, stream);
+}
 
 void ggml_cuda_op_mul_mat_vec_q_id(
     ggml_backend_cuda_context & ctx,
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index da424ab3..5e92dc2a 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -9,12 +9,20 @@
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_op_mul_mat_vec_q_biased(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
+    const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
 void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
 bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
+
 void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 3f8772e1..917ef027 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1240,14 +1240,17 @@ std::tuple llm_build_context::llm_buil
     if (bq) {
         Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
         cb(Qcur, "Qcur", il);
+        ggml_build_forward_expand(gf, Qcur);
     }
     if (bk) {
         Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
         cb(Kcur, "Kcur", il);
+        ggml_build_forward_expand(gf, Kcur);
     }
     if (bv) {
         Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
         cb(Vcur, "Vcur", il);
+        ggml_build_forward_expand(gf, Vcur);
     }
     return {Qcur, Kcur, Vcur};
 }