mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
cuda: grouped top_k
This commit is contained in:
@@ -25,8 +25,8 @@ struct store {
|
|||||||
constexpr static bool has_thresh = false;
|
constexpr static bool has_thresh = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<ggml_sort_order order, typename Store>
|
template<ggml_sort_order order, typename Store, typename dst_t>
|
||||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, int ntop, Store s) {
|
static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) {
|
||||||
// int min_experts, float thresh_experts) {
|
// int min_experts, float thresh_experts) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
int col = threadIdx.x;
|
int col = threadIdx.x;
|
||||||
@@ -74,15 +74,25 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
float max_val = x_row[dst_row[0]];
|
float max_val = x_row[dst_row[0]];
|
||||||
if (col < ntop) {
|
if (col < ntop) {
|
||||||
dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
|
if constexpr (std::is_same_v<dst_t, int>) {
|
||||||
|
dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1;
|
||||||
|
} else {
|
||||||
|
dst[row * ntop + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? x_row[dst_row[col]] : 0.f;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (col < ntop) {
|
if (col < ntop) {
|
||||||
dst[row * ntop + col] = dst_row[col];
|
if constexpr (std::is_same_v<dst_t, int>) {
|
||||||
|
dst[row * ntop + col] = dst_row[col];
|
||||||
|
} else {
|
||||||
|
dst[row * ntop + col] = x_row[dst_row[col]];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// Somehow this is not working. Someone sees the bug?
|
||||||
template<ggml_sort_order order>
|
template<ggml_sort_order order>
|
||||||
static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
|
static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
@@ -130,6 +140,7 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols,
|
|||||||
float val = col < n_top_k ? x[dst_row[col]] : 0;
|
float val = col < n_top_k ? x[dst_row[col]] : 0;
|
||||||
val = warp_reduce_sum(val);
|
val = warp_reduce_sum(val);
|
||||||
if (blockDim.x > WARP_SIZE) {
|
if (blockDim.x > WARP_SIZE) {
|
||||||
|
__syncthreads();
|
||||||
auto s_sum = dst_row;
|
auto s_sum = dst_row;
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
@@ -148,6 +159,7 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols,
|
|||||||
dst[row] = val;
|
dst[row] = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static __global__ void k_apply_mask(float * dst, const int * groups,
|
static __global__ void k_apply_mask(float * dst, const int * groups,
|
||||||
const int n_top_groups, const int n_per_group, const int ncols) {
|
const int n_top_groups, const int n_per_group, const int ncols) {
|
||||||
@@ -167,7 +179,8 @@ static int next_power_of_2(int x) {
|
|||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, int ntop,
|
template <typename dst_t>
|
||||||
|
static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, const int nrows, int ntop,
|
||||||
ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
|
ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
|
||||||
// bitonic sort requires ncols to be power of 2
|
// bitonic sort requires ncols to be power of 2
|
||||||
const int ncols_pad = next_power_of_2(ncols);
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
@@ -181,17 +194,17 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
|||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
|
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
|
k_argsort_f32_T<GGML_SORT_ORDER_ASC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
|
||||||
ntop, {min_experts, thresh_experts});
|
ntop, {min_experts, thresh_experts});
|
||||||
} else {
|
} else {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
|
k_argsort_f32_T<GGML_SORT_ORDER_ASC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
|
||||||
}
|
}
|
||||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||||
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
|
if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
|
k_argsort_f32_T<GGML_SORT_ORDER_DESC, store_ser><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad,
|
||||||
ntop, {min_experts, thresh_experts});
|
ntop, {min_experts, thresh_experts});
|
||||||
} else {
|
} else {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
|
k_argsort_f32_T<GGML_SORT_ORDER_DESC, store><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop, {});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
@@ -213,7 +226,7 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
|
argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, order, -1, 0.f, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
@@ -233,9 +246,10 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
float thresh;
|
float thresh;
|
||||||
memcpy(&thresh, dst->op_params + 1, sizeof(float));
|
memcpy(&thresh, dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
|
argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
|
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
|
||||||
|
|
||||||
GGML_ASSERT(n_top_k <= ncols);
|
GGML_ASSERT(n_top_k <= ncols);
|
||||||
@@ -249,6 +263,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float *
|
|||||||
|
|
||||||
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
|
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
auto src = dst->src[0];
|
auto src = dst->src[0];
|
||||||
@@ -271,15 +286,22 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
|||||||
GGML_ASSERT(n_top_groups < n_groups);
|
GGML_ASSERT(n_top_groups < n_groups);
|
||||||
int n_discarded_groups = n_groups - n_top_groups;
|
int n_discarded_groups = n_groups - n_top_groups;
|
||||||
|
|
||||||
|
ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
|
||||||
|
argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
|
||||||
|
GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
||||||
//ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
|
sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
|
||||||
sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream());
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
// This is not working for some reason, so we resort to the slightly less efficient implementation above
|
||||||
|
//ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
||||||
|
//ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
|
||||||
|
////sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream());
|
||||||
|
//CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
|
ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
|
||||||
argsort_f32_i32_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
|
argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -290,6 +312,6 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
|||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
}
|
}
|
||||||
|
|
||||||
argsort_f32_i32_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
|
argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user