Fuse topk+view+get_rows+reshape+softmax (CUDA)

This commit is contained in:
Iwan Kawrakow
2025-10-19 13:00:37 +03:00
parent c8ed454564
commit b79aad9d07
3 changed files with 105 additions and 5 deletions

View File

@@ -3336,7 +3336,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
if (i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
i += 5;
} else {
ggml_cuda_op_argsort(ctx, dst);
}
break;
case GGML_OP_ARGSORT_THRESH:
ggml_cuda_op_argsort_thresh(ctx, dst);

View File

@@ -153,6 +153,53 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float
}
}
template<ggml_sort_order order>
static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols_pad) {
return;
}
extern __shared__ int dst_row[];
auto x_row = x + row*ncols;
// initialize indices
dst_row[col] = col;
__syncthreads();
sort<order>(ncols_pad, ncols, col, x_row, dst_row);
float max = x_row[dst_row[0]];
float val = col < ntop ? expf(x_row[dst_row[col]] - max) : 0.0f;
float sum = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
__syncthreads();
float * s_sum = (float *)(dst_row + ncols_pad);
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
float norm = 1/sum;
if (col < ntop) {
weights[row * ntop + col] = norm*val;
auto row_ids = (int *)((char *)ids + row*nb_ids);
row_ids[col] = dst_row[col];
}
}
template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) {
// bitonic sort
@@ -299,6 +346,29 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias,
}
}
static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_openai_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else if (order == GGML_SORT_ORDER_DESC) {
k_openai_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -465,11 +535,29 @@ void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);
//printf("probs: %ld x %ld x %ld x %ld. topk: %ld x %ld x %ld x %ld. dst: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",
// probs->ne[0], probs->ne[1], probs->ne[2], probs->ne[3], topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3],
// dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());
}
void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax) {
auto probs = topk->src[0];
int ntop = topk->op_params[1];
auto nrows = ggml_nrows(probs);
int ne00 = probs->ne[0];
int ne0 = softmax->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(ggml_is_contiguous(softmax));
GGML_ASSERT(ne0 <= ne00);
if (ntop != ne0) {
printf("Oops: ntop = %d, ne0 = %d\n", ntop, ne0);
GGML_ASSERT(false);
}
//GGML_ASSERT(ne0 == ntop);
argsort_openai_f32_f32_i32_cuda((const float *)probs->data, (float *)softmax->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());
}

View File

@@ -15,3 +15,5 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);
void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);
void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax);