Also fuse sum_rows and div

This commit is contained in:
Iwan Kawrakow
2025-10-19 18:04:13 +03:00
parent 0fb9d4963f
commit 1d70b89d35
6 changed files with 86 additions and 2 deletions

View File

@@ -3333,7 +3333,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_pool2d(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
if (i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == dst &&
cgraph->nodes[i+1]->src[0] == dst->src[0]) {
ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
++i;
} else {
ggml_cuda_op_sum_rows(ctx, dst);
}
break;
case GGML_OP_ARGSORT:
if (i + 5 < cgraph->n_nodes &&

View File

@@ -16,12 +16,38 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}
// Fused SUM_ROWS + DIV kernel: writes dst[row][i] = x[row][i] / sum(x[row][:]).
//
// Launch layout: one block per row (blockIdx.x == row index), and
// blockDim.x must be WARP_SIZE — warp_reduce_sum only reduces within a
// single warp, so a larger block would leave per-warp partial sums unreduced.
//
// NOTE(review): a row whose sum is <= 0 is written as all zeros instead of
// dividing (avoids inf/NaN) — presumably the inputs are non-negative
// routing/softmax weights; confirm against callers.
static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    // Use a 64-bit row offset: row * ncols can overflow 32-bit int for large tensors.
    const int64_t row_offset = (int64_t)row * ncols;

    float sum = 0.0f;
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x[row_offset + i];
    }
    sum = warp_reduce_sum(sum);

    // Multiply by the reciprocal once instead of dividing per element.
    const float norm = sum > 0.0f ? 1.0f/sum : 0.0f;
    for (int i = col; i < ncols; i += blockDim.x) {
        dst[row_offset + i] = x[row_offset + i] * norm;
    }
}
// Launch the row-sum kernel: one warp-sized block per row of x.
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 threads_per_block(WARP_SIZE, 1, 1);
    k_sum_rows_f32<<<num_blocks, threads_per_block, 0, stream>>>(x, dst, ncols);
}
// Launch the fused sum-rows + divide kernel: one warp-sized block per row of x.
static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 threads_per_block(WARP_SIZE, 1, 1);
    k_sum_rows_div_f32<<<num_blocks, threads_per_block, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -38,3 +64,19 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}
// Fused SUM_ROWS + DIV entry point. `dst` here is the DIV node of the fused
// pair; its src[0] is the tensor being normalized (the SUM_ROWS input).
// Computes dst = src0 / sum_rows(src0) in one kernel, skipping the
// intermediate row-sum tensor entirely.
void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    // The kernel writes dst with the same flat row-major layout it reads
    // src0 with, so dst must be contiguous as well.
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

View File

@@ -3,3 +3,5 @@
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -22357,7 +22357,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
} break;
case GGML_OP_SUM_ROWS:
{
ggml_compute_forward_sum_rows(params, tensor);
if (i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == tensor &&
cgraph->nodes[i+1]->src[0] == tensor->src[0]) {
iqk_sumrows_div(cgraph->nodes[i+1], params->ith, params->nth);
++i;
} else {
ggml_compute_forward_sum_rows(params, tensor);
}
} break;
case GGML_OP_MEAN:
{

View File

@@ -124,6 +124,28 @@ inline void biased_sigmoid(int n, const float * x, const float * bias, float * y
}
}
// CPU fused SUM_ROWS + DIV. `div` is the DIV node of the fused pair and
// div->src[0] is the tensor being normalized. The nth threads (this call is
// thread ith) split the rows into contiguous slices; each row is written as
// result[j] = values[j] / sum(values).
//
// NOTE(review): rows whose sum is <= 0 are written as all zeros instead of
// divided (avoids inf/NaN) — presumably inputs are non-negative weights;
// confirm against callers. Matches the CUDA kernel's behavior.
void iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth) {
    auto src = div->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(div->type == GGML_TYPE_F32);

    const int64_t ne00  = src->ne[0];
    // ggml_nrows returns int64_t — keep it 64-bit; truncating to int breaks
    // tensors with more than INT_MAX rows and can overflow first/last below.
    const int64_t nrows = ggml_nrows(src);
    const int64_t npt   = (nrows + nth - 1)/nth;
    const int64_t first = ith*npt;
    const int64_t last  = std::min(first + npt, nrows);
    if (last <= first) return; // this thread has no rows (nth > nrows)

    for (int64_t ir = first; ir < last; ++ir) {
        auto values = (const float *)((const char *)src->data + ir*src->nb[1]);
        float sum = 0.0f;
        for (int64_t j = 0; j < ne00; ++j) sum += values[j];
        const float norm = sum > 0.0f ? 1.0f/sum : 0.0f;
        auto result = (float *)((char *)div->data + ir*div->nb[1]);
        for (int64_t j = 0; j < ne00; ++j) result[j] = values[j]*norm;
    }
}
void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) {
auto src = dst->src[0];
GGML_ASSERT(dst->type == GGML_TYPE_I32);

View File

@@ -14,6 +14,8 @@ extern "C" {
struct ggml_tensor;
void iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth);
void iqk_grouped_top_k(struct ggml_tensor * dst, int ith, int nth);
void iqk_argsort(struct ggml_tensor * dst, int ith, int nth);