From 2bcbf9f81e074d864ace1f718ab4c1f82f68e253 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 17 Oct 2025 18:38:53 +0300 Subject: [PATCH] cuda: grouped top_k --- ggml/src/ggml-cuda/argsort.cu | 56 ++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index e66b3e12..6fce5560 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -25,8 +25,8 @@ struct store { constexpr static bool has_thresh = false; }; -template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, int ntop, Store s) { +template +static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) { // int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; @@ -74,15 +74,25 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n __syncthreads(); float max_val = x_row[dst_row[0]]; if (col < ntop) { - dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1; + if constexpr (std::is_same_v) { + dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1; + } else { + dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? x_row[dst_row[col]] : 0.f; + } } } else { if (col < ntop) { - dst[row * ntop + col] = dst_row[col]; + if constexpr (std::is_same_v) { + dst[row * ntop + col] = dst_row[col]; + } else { + dst[row * ntop + col] = x_row[dst_row[col]]; + } } } } +#if 0 +// Somehow this is not working. Someone sees the bug? template static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort @@ -130,6 +140,7 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, float val = col < n_top_k ? x[dst_row[col]] : 0; val = warp_reduce_sum(val); if (blockDim.x > WARP_SIZE) { + __syncthreads(); auto s_sum = dst_row; const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; @@ -148,6 +159,7 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, dst[row] = val; } } +#endif static __global__ void k_apply_mask(float * dst, const int * groups, const int n_top_groups, const int n_per_group, const int ncols) { @@ -167,7 +179,8 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, int ntop, +template +static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, const int nrows, int ntop, ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -181,17 +194,17 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co if (order == GGML_SORT_ORDER_ASC) { if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + k_argsort_f32_T<<>>(x, dst, ncols, ncols_pad, ntop, {min_experts, thresh_experts}); } else { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, ntop, {}); + k_argsort_f32_T<<>>(x, dst, ncols, ncols_pad, ntop, {}); } } else if (order == GGML_SORT_ORDER_DESC) { if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + k_argsort_f32_T<<>>(x, dst, ncols, ncols_pad, ntop, {min_experts, thresh_experts}); } else { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, ntop, {}); + k_argsort_f32_T<<>>(x, dst, ncols, ncols_pad, ntop, {}); } } else { GGML_ABORT("fatal error"); @@ -213,7 +226,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream); + argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream); } void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -233,9 +246,10 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * float thresh; memcpy(&thresh, dst->op_params + 1, sizeof(float)); - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); + argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } +#if 0 static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) { GGML_ASSERT(n_top_k <= ncols); @@ -249,6 +263,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * k_topk_sum<<>>(src, dst, ncols, ncols_pad, n_top_k); } +#endif void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { auto src = dst->src[0]; @@ -271,15 +286,22 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds GGML_ASSERT(n_top_groups < n_groups); int n_discarded_groups = n_groups - n_top_groups; - - + ggml_cuda_pool_alloc sorted_group_scores(ctx.pool(), nk*nrows*n_groups); + argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk, + GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); - //ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk); - sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream()); + sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream()); CUDA_CHECK(cudaGetLastError()); + // This is not working for some reason, so we resort to the slightly less efficient implementation above + //ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); + //ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk); + ////sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream()); + //CUDA_CHECK(cudaGetLastError()); + ggml_cuda_pool_alloc discarded_groups(ctx.pool(), nrows*n_discarded_groups); - argsort_f32_i32_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream()); + argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream()); CUDA_CHECK(cudaGetLastError()); { @@ -290,6 +312,6 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds CUDA_CHECK(cudaGetLastError()); } - argsort_f32_i32_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream()); + argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream()); }