From fa6fb271f31a9a9e5e99b2769cc9fc20bf518fb5 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 17 Oct 2025 18:07:03 +0300
Subject: [PATCH] WIP

---
 ggml/src/ggml-cuda.cu          |   4 +
 ggml/src/ggml-cuda/argsort.cu  | 174 ++++++++++++++++++++++++++++-----
 ggml/src/ggml-cuda/argsort.cuh |   2 +
 ggml/src/ggml-cuda/sumrows.cu  |   2 +-
 ggml/src/ggml-cuda/sumrows.cuh |   2 +
 5 files changed, 155 insertions(+), 28 deletions(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index da43d4cd..fd5a5cac 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3323,6 +3323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT_THRESH:
             ggml_cuda_op_argsort_thresh(ctx, dst);
             break;
+        case GGML_OP_GROUPED_TOPK:
+            ggml_cuda_op_grouped_topk(ctx, dst);
+            break;
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
@@ -4332,6 +4335,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ARGSORT_THRESH:
+        case GGML_OP_GROUPED_TOPK:
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 0442dc90..e66b3e12 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "argsort.cuh"
+#include "sumrows.cuh"
 
 template <typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
@@ -25,7 +26,7 @@ struct store {
 };
 
 template <ggml_sort_order order, typename Store>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) { // int min_experts, float thresh_experts) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, int ntop, Store s) { // int min_experts, float thresh_experts) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
@@ -72,27 +73,91 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
     if constexpr (Store::has_thresh) {
         __syncthreads();
         float max_val = x_row[dst_row[0]];
-        if (col < ncols) {
-            dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
+        if (col < ntop) {
+            dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
         }
     } else {
-        if (col < ncols) {
-            dst[row * ncols + col] = dst_row[col];
+        if (col < ntop) {
+            dst[row * ntop + col] = dst_row[col];
         }
     }
-    //if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
-    //    __syncthreads();
-    //    float max_val = x_row[dst_row[0]];
-    //    if (col < ncols) {
-    //        dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
-    //    }
-    //}
-    //else {
-    //    // copy the result to dst without the padding
-    //    if (col < ncols) {
-    //        dst[row * ncols + col] = dst_row[col];
-    //    }
-    //}
 }
+
+// Bitonic-sorts one row of x (blockDim.x == ncols_pad threads, one block per row),
+// then sums the n_top_k best values and writes the per-row sum to dst[row].
+// Requires max(ncols_pad, WARP_SIZE) ints of dynamic shared memory; the index
+// buffer is reused (as float storage) for the cross-warp reduction.
+template <ggml_sort_order order>
+static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    float val = col < n_top_k ? x_row[dst_row[col]] : 0.0f;
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        float * s_sum = (float *) dst_row;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        __syncthreads();
+        if (lane_id == 0) {
+            s_sum[warp_id] = val;
+        }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = s_sum[lane_id];
+        }
+        val = warp_reduce_sum(val);
+    }
+
+    if (col == 0) {
+        dst[row] = val;
+    }
+}
+
+// Sets the n_per_group entries of every listed group to -inf.
+// One block per row; groups holds n_top_groups group indices per row.
+static __global__ void k_apply_mask(float * dst, const int * groups,
+        const int n_top_groups, const int n_per_group, const int ncols) {
+    int row = blockIdx.y;
+    for (int col = threadIdx.x; col < n_top_groups*n_per_group; col += blockDim.x) {
+        int ig = groups[row*n_top_groups + col / n_per_group];
+        int ic = col % n_per_group;
+        dst[row*ncols + ig*n_per_group + ic] = -INFINITY;
+    }
+}
 
 static int next_power_of_2(int x) {
@@ -103,7 +168,7 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows,
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, int ntop,
                                  ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
@@ -118,19 +183,17 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     if (order == GGML_SORT_ORDER_ASC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
             k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_thresh><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        //k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else if (order == GGML_SORT_ORDER_DESC) {
         if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
             k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_thresh><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
-                    {min_experts, thresh_experts});
+                    ntop, {min_experts, thresh_experts});
         } else {
-            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, {});
+            k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
         }
-        //k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -151,7 +214,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream);
+    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
 }
 
 void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -171,5 +234,62 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
     float thresh;
     memcpy(&thresh, dst->op_params + 1, sizeof(float));
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
+    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
 }
+
+// Per-row sum of the n_top_k largest values (currently unused, see below).
+static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
+
+    GGML_ASSERT(n_top_k <= ncols);
+
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int);
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
+}
+
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto src = dst->src[0];
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nrows(src) == ggml_nrows(dst));
+
+    auto nrows = ggml_nrows(src);
+
+    int n_groups     = dst->op_params[0];
+    int n_top_groups = dst->op_params[1];
+    int nk           = dst->op_params[2];
+
+    int ne00 = src->ne[0];
+    int ne0  = dst->ne[0];
+    GGML_ASSERT(ne0 <= ne00);
+    GGML_ASSERT(ne00%n_groups == 0);
+    int n_per_group = ne00/n_groups;
+    GGML_ASSERT(nk <= n_per_group);
+    GGML_ASSERT(n_top_groups < n_groups);
+    int n_discarded_groups = n_groups - n_top_groups;
+
+    ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
+    //ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
+    sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+
+    ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
+    argsort_f32_i32_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
+    CUDA_CHECK(cudaGetLastError());
+
+    // WIP: masks the discarded groups directly in src->data (destructive to the source tensor)
+    {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        const dim3 block_nums(1, nrows, 1);
+        cudaStream_t stream = ctx.stream();
+        k_apply_mask<<<block_nums, block_dims, 0, stream>>>((float *)src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    argsort_f32_i32_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
+}
diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh
index 7bbfdf4d..7bd28a1f 100644
--- a/ggml/src/ggml-cuda/argsort.cuh
+++ b/ggml/src/ggml-cuda/argsort.cuh
@@ -9,3 +9,5 @@
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu
index 82e8e875..40be14cf 100644
--- a/ggml/src/ggml-cuda/sumrows.cu
+++ b/ggml/src/ggml-cuda/sumrows.cu
@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
     }
 }
 
-static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh
index e7545f83..0c0f4783 100644
--- a/ggml/src/ggml-cuda/sumrows.cuh
+++ b/ggml/src/ggml-cuda/sumrows.cuh
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);