diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index da5df81a..da43d4cd 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -45,6 +45,7 @@ #include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/argmax.cuh" #include #include @@ -3106,6 +3107,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr; switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_cuda_argmax(ctx, dst); + break; case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); break; @@ -4272,6 +4276,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return false; } break; + case GGML_OP_ARGMAX: + return true; case GGML_OP_DUP: case GGML_OP_REPEAT: case GGML_OP_CONCAT: diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu new file mode 100644 index 00000000..754e133d --- /dev/null +++ b/ggml/src/ggml-cuda/argmax.cu @@ -0,0 +1,90 @@ +#include +#include + +#include "argmax.cuh" +#include "common.cuh" + +static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) { + const int64_t row = blockIdx.x; + + float maxval = -FLT_MAX; + int argmax = -1; + const float * rowx = x + row * ncols; + + for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { + const float val = rowx[col]; + if (val > maxval) { + maxval = val; + argmax = col; + } + } + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + + const int n_warps = blockDim.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + if (n_warps > 1) { + constexpr int max_warps = 1024 / WARP_SIZE; + __shared__ float shared_maxval[max_warps]; + __shared__ int shared_argmax[max_warps]; + if (lane_id == 0) { + shared_maxval[warp_id] = maxval; + shared_argmax[warp_id] = argmax; + } + + __syncthreads(); + + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + } + } + + if (warp_id == 0 && lane_id == 0) { + dst[row] = argmax; + } +} + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + int32_t * dst_d = (int32_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t num_blocks = nrows; + const int64_t num_threads = std::min(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); + const dim3 blocks_dim(num_threads, 1, 1); + const dim3 blocks_num(num_blocks, 1, 1); + + argmax_f32<<>>(src0_d, dst_d, ne00); +} diff --git a/ggml/src/ggml-cuda/argmax.cuh b/ggml/src/ggml-cuda/argmax.cuh new file mode 100644 index 00000000..5b7223ad --- /dev/null +++ b/ggml/src/ggml-cuda/argmax.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index df214082..0442dc90 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -13,9 +13,20 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { b = tmp; } -template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, - int min_experts, float thresh_experts) { +struct store_ser { + constexpr static bool has_thresh = true; + int min_experts; + float thresh_experts; + store_ser(int min, float thresh) : min_experts(min), thresh_experts(thresh) {} +}; + +struct store { + constexpr static bool has_thresh = false; +}; + +template +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) { +// int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -58,19 +69,30 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n } } - if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + if constexpr (Store::has_thresh) { __syncthreads(); float max_val = x_row[dst_row[0]]; if (col < ncols) { - dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1; } - } - else { - // copy the result to dst without the padding + } else { if (col < ncols) { dst[row * ncols + col] = dst_row[col]; } } + //if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + // __syncthreads(); + // float max_val = x_row[dst_row[0]]; + // if (col < ncols) { + // dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + // } + //} + //else { + // // copy the result to dst without the padding + // if (col < ncols) { + // dst[row * ncols + col] = dst_row[col]; + // } + //} } static int next_power_of_2(int x) { @@ -94,9 +116,21 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + {min_experts, thresh_experts}); + } else { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, {}); + } + //k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + {min_experts, thresh_experts}); + } else { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, {}); + } + //k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 04aa3e26..68dec960 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -422,7 +422,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -473,6 +477,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { // This is needed for MLA with mla=2 when using q8_0 cache. transpose_q8_0(ctx, src0, src1); @@ -498,7 +506,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - return nullptr; + // Prioritize CUDA graph compatibility over direct memory copy optimization. + // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. + if (src0->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; + } else { + return nullptr; + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { @@ -545,6 +559,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { return (void *)transpose_q8_0; } else { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 02e06b07..e1817bae 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -820,18 +820,18 @@ llm_expert_gating_func_type gating_op, selection_probs = logits; } - if (lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { + if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { auto& hparams = lctx.model.hparams; const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; // organize experts into n_expert_groups ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups] #if 0 - ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups] - group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] + ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups] + group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] // get top n_group_used expert groups - group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] + group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] #else // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]