Fuse sigmoid+add+topk+get_rows (CUDA)

This commit is contained in:
Iwan Kawrakow
2025-10-19 09:13:34 +03:00
parent f3ff1a5c48
commit 8fe2bb927a
3 changed files with 99 additions and 2 deletions

View File

@@ -3173,12 +3173,29 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
if (i + 4 < cgraph->n_nodes &&
if (i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
i += 5;
}
//else if (i + 5 < cgraph->n_nodes) {
// printf("sigmoid(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s)\n", dst->name,
// ggml_op_name(cgraph->nodes[i+1]->op), cgraph->nodes[i+1]->name,
// ggml_op_name(cgraph->nodes[i+2]->op), cgraph->nodes[i+2]->name,
// ggml_op_name(cgraph->nodes[i+3]->op), cgraph->nodes[i+3]->name,
// ggml_op_name(cgraph->nodes[i+4]->op), cgraph->nodes[i+4]->name,
// ggml_op_name(cgraph->nodes[i+5]->op), cgraph->nodes[i+5]->name);
//}
else if (i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+4]);
i += 4;
} else {
ggml_cuda_op_sigmoid(ctx, dst);

View File

@@ -124,6 +124,35 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float
}
}
template<ggml_sort_order order>
static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float * bias, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols_pad) {
return;
}
extern __shared__ int dst_row[];
auto x_row = (float *)(dst_row + ncols_pad);
// initialize indices
dst_row[col] = col;
x_row[col] = col < ncols ? 1/(1 + expf(-x[row*ncols + col])) + bias[col] : -INFINITY;
__syncthreads();
sort<order>(ncols_pad, ncols, col, x_row, dst_row);
if (col < ntop) {
weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]]));
auto row_ids = (int *)((char *)ids + row*nb_ids);
row_ids[col] = dst_row[col];
}
}
template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) {
// bitonic sort
@@ -247,6 +276,29 @@ static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, fl
}
}
static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float));
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_biased_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, bias, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_biased_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, bias, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
} else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -395,3 +447,29 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());
}
void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk_view) {
GGML_ASSERT(topk_view->op == GGML_OP_VIEW);
auto topk = topk_view->src[0];
auto topk_src = topk->src[0];
auto probs = topk_src->src[0]->src[0];
auto bias = topk_src->src[1];
auto nrows = ggml_nrows(probs);
int ne00 = probs->ne[0];
int ne0 = topk_view->ne[0];
GGML_ASSERT(ggml_is_contiguous(probs));
GGML_ASSERT(bias->ne[1] == 1);
GGML_ASSERT(bias->ne[0] == probs->ne[0]);
GGML_ASSERT(ne0 == dst->ne[1]);
GGML_ASSERT(ne0 <= ne00);
//printf("probs: %ld x %ld x %ld x %ld. topk: %ld x %ld x %ld x %ld. dst: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n",
// probs->ne[0], probs->ne[1], probs->ne[2], probs->ne[3], topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3],
// dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());
}

View File

@@ -13,3 +13,5 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);
void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk);