mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
This is very slightly better
This commit is contained in:
@@ -91,8 +91,6 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if 0
|
|
||||||
// Somehow this is not working. Someone sees the bug?
|
|
||||||
template<ggml_sort_order order>
|
template<ggml_sort_order order>
|
||||||
static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
|
static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
@@ -137,11 +135,11 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float val = col < n_top_k ? x[dst_row[col]] : 0;
|
float val = col < n_top_k ? x_row[dst_row[col]] : 0;
|
||||||
val = warp_reduce_sum(val);
|
val = warp_reduce_sum(val);
|
||||||
if (blockDim.x > WARP_SIZE) {
|
if (blockDim.x > WARP_SIZE) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
auto s_sum = dst_row;
|
float * s_sum = (float *)dst_row;
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
if (lane_id == 0) {
|
if (lane_id == 0) {
|
||||||
@@ -159,7 +157,6 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols,
|
|||||||
dst[row] = val;
|
dst[row] = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
static __global__ void k_apply_mask(float * dst, const int * groups,
|
static __global__ void k_apply_mask(float * dst, const int * groups,
|
||||||
const int n_top_groups, const int n_per_group, const int ncols) {
|
const int n_top_groups, const int n_per_group, const int ncols) {
|
||||||
@@ -249,7 +246,6 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
|
argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if 0
|
|
||||||
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
|
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) {
|
||||||
|
|
||||||
GGML_ASSERT(n_top_k <= ncols);
|
GGML_ASSERT(n_top_k <= ncols);
|
||||||
@@ -263,7 +259,6 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float *
|
|||||||
|
|
||||||
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
|
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, dst, ncols, ncols_pad, n_top_k);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
auto src = dst->src[0];
|
auto src = dst->src[0];
|
||||||
@@ -286,6 +281,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
|||||||
GGML_ASSERT(n_top_groups < n_groups);
|
GGML_ASSERT(n_top_groups < n_groups);
|
||||||
int n_discarded_groups = n_groups - n_top_groups;
|
int n_discarded_groups = n_groups - n_top_groups;
|
||||||
|
|
||||||
|
#if 0
|
||||||
ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
|
ggml_cuda_pool_alloc<float> sorted_group_scores(ctx.pool(), nk*nrows*n_groups);
|
||||||
argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
|
argsort_f32_T_cuda((const float *)src->data, sorted_group_scores.get(), n_per_group, nrows*n_groups, nk,
|
||||||
GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
|
GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream());
|
||||||
@@ -293,12 +289,11 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
|||||||
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
||||||
sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
|
sum_rows_f32_cuda((const float *)sorted_group_scores.get(), group_scores.get(), nk, nrows*n_groups, ctx.stream());
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
#else
|
||||||
// This is not working for some reason, so we resort to the slightly less efficient implementation above
|
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
||||||
//ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
|
ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
|
||||||
//ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk);
|
CUDA_CHECK(cudaGetLastError());
|
||||||
////sum_rows_f32_cuda((const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, ctx.stream());
|
#endif
|
||||||
//CUDA_CHECK(cudaGetLastError());
|
|
||||||
|
|
||||||
ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
|
ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
|
||||||
argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
|
argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
|
||||||
|
|||||||
Reference in New Issue
Block a user