Fuse add+add+fused_rms (#853)

* Fuse add+add+fused_rms

* Try this

* Macro to easily enable/disable fusion

* Various:

* Check that all tensors involved are on the same device before applying fusion
* Fuse sigmoid+scale+sum_rows+div
* Fix the fused bailingmoe2 experts selection

The issue there was that the bias was not per row, but per
expert group, so only the first n_per_group biases were used
for al experts.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-10-22 16:18:11 +03:00
committed by GitHub
parent af5bf60cc8
commit ed4e1a6588
8 changed files with 281 additions and 54 deletions

View File

@@ -3094,12 +3094,28 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, int last) {
if (last <= first) return true;
int device = ((const ggml_backend_cuda_buffer_context *)cgraph->nodes[first]->buffer->context)->device;
for (int i = first; i <= last; ++i) {
auto node = cgraph->nodes[i];
if (((const ggml_backend_cuda_buffer_context *)node->buffer->context)->device != device) return false;
for (int j = 0; j < GGML_MAX_SRC; ++j) {
if (!node->src[j] || !node->src[j]->buffer) continue;
if (((const ggml_backend_cuda_buffer_context *)node->src[j]->buffer->context)->device != device) return false;
}
}
return true;
}
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
// why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
}
#define ENABLE_FUSION true
#if IK_PRINT_TIMING
int64_t tim1 = ggml_time_us();
#endif
@@ -3129,17 +3145,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_dup(ctx, dst);
break;
case GGML_OP_ADD:
if (i + 1 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_ADD &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
ggml_is_contiguous(dst->src[1]) &&
ggml_are_same_shape(dst->src[0], dst->src[1]) &&
dst == cgraph->nodes[i+1]->src[0] &&
ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
ops_are_same_device(cgraph, i, i+2)) {
//printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
}
else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
ggml_is_contiguous(dst->src[0]) &&
ggml_is_contiguous(dst->src[1]) &&
ggml_are_same_shape(dst->src[0], dst->src[1])) {
ggml_are_same_shape(dst->src[0], dst->src[1]) &&
dst == cgraph->nodes[i+1]->src[0] && ops_are_same_device(cgraph, i, i+1)) {
ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]);
++i;
} else {
ggml_cuda_op_add(ctx, dst);
}
//ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_ADD_ID:
ggml_cuda_op_add_id(ctx, dst);
@@ -3183,22 +3214,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
if (i + 5 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+4]->op == GGML_OP_VIEW &&
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) {
cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+5)) {
cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
i += 5;
}
else if (i + 4 < cgraph->n_nodes &&
else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD &&
cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK &&
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) {
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) {
cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
i += 4;
} else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) {
ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
i += 2;
} else {
ggml_cuda_op_sigmoid(ctx, dst);
}
@@ -3309,12 +3345,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_diag_mask_inf(ctx, dst);
break;
case GGML_OP_SOFT_MAX:
if (i + 4 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) {
ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) &&
ops_are_same_device(cgraph, i, i+4)) {
ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]);
i += 4;
} else {
@@ -3343,10 +3380,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_pool2d(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
if (i + 1 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
cgraph->nodes[i+2]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[0] == dst &&
cgraph->nodes[i+2]->src[1] == cgraph->nodes[i+1] &&
cgraph->nodes[i+2]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+2)) {
ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
i += 2;
}
else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_DIV &&
cgraph->nodes[i+1]->src[1] == dst &&
cgraph->nodes[i+1]->src[0] == dst->src[0]) {
cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) {
ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]);
++i;
} else {
@@ -3354,12 +3400,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
break;
case GGML_OP_ARGSORT:
if (i + 5 < cgraph->n_nodes &&
if (ENABLE_FUSION && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX &&
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) {
cgraph->nodes[i+5]->op == GGML_OP_RESHAPE && ops_are_same_device(cgraph, i, i+4)) {
cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]);
i += 5;
} else {
@@ -3390,6 +3436,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1));
#endif
#undef ENABLE_FUSION
return true;
}

View File

@@ -59,7 +59,7 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
// int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -97,17 +97,17 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n
}
template<ggml_sort_order order>
static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop,
size_t nb_ids) {
static __global__ void k_argsort_f32_u8(const float * x, uint8_t * dst, const int ncols, int ncols_pad, int ntop) {
// int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
}
const float * x_row = x_biased + row * ncols;
const float * x_row = x + row * ncols;
extern __shared__ int dst_row[];
// initialize indices
@@ -117,6 +117,32 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float
sort<order>(ncols_pad, ncols, col, x_row, dst_row);
if (col < ncols) dst[row*ncols + dst_row[col]] = col < ntop ? 1 : 0;
}
template<ggml_sort_order order>
static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, const uint8_t * group_mask,
float * weights, int * ids, const int ncols, int ncols_pad, int ntop, size_t nb_ids, int n_per_group, int n_groups) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
}
extern __shared__ int dst_row[];
auto x_row = (float *)(dst_row + ncols_pad);
// initialize indices
dst_row[col] = col;
int ig = col / n_per_group;
x_row[col] = ig < n_groups && group_mask[row*n_groups + ig] ? x_biased[row * ncols + col] : -INFINITY;
__syncthreads();
sort<order>(ncols_pad, ncols, col, x_row, dst_row);
if (col < ntop) {
weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]]));
auto row_ids = (int *)((char *)ids + row*nb_ids);
@@ -129,7 +155,7 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -158,7 +184,7 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in
size_t nb_ids) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -201,10 +227,11 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in
}
template<ggml_sort_order order>
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) {
static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst,
const int ne00, const int ncols, int ncols_pad, int n_top_k) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -218,7 +245,7 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x
if (bias && x_p) {
float * x_p_row = x_p + row * ncols;
if (col < ncols) {
x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[col];
x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[(row * ncols + col)%ne00];
}
x_row = x_p_row;
}
@@ -227,6 +254,10 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x
sort<order>(ncols_pad, ncols, col, x_row, dst_row);
if (n_top_k == 2) {
float val = x_row[dst_row[0]] + x_row[dst_row[1]];
if (col == 0) dst[row] = val;
} else {
float val = col < n_top_k ? x_row[dst_row[col]] : 0;
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
@@ -248,6 +279,7 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x
if (col == 0) {
dst[row] = val;
}
}
}
static __global__ void k_apply_mask(float * dst, const int * groups,
@@ -275,7 +307,7 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -300,24 +332,46 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co
}
}
static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, ggml_sort_order order, cudaStream_t stream) {
static void argsort_f32_u8_cuda(const float * x, uint8_t * dst, const int ncols, const int nrows, int ntop,
ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
k_argsort_f32_u8<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, weights, ids,
ncols, ncols_pad, ntop, nb_ids);
k_argsort_f32_u8<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, ntop);
} else {
GGML_ABORT("fatal error");
}
}
static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, const uint8_t * group_mask,
float * weights, int * ids, const int ncols, const int nrows, int ntop,
size_t nb_ids, int n_per_group, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float));
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, group_mask, weights, ids,
ncols, ncols_pad, ntop, nb_ids, n_per_group, ncols/n_per_group);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x_biased, x, group_mask, weights, ids,
ncols, ncols_pad, ntop, nb_ids, n_per_group, ncols/n_per_group);
} else {
GGML_ABORT("fatal error");
}
@@ -329,7 +383,7 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias,
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float));
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -352,7 +406,7 @@ static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, in
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
@@ -408,18 +462,18 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor *
}
static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, const float * bias, float * src_p, float * dst,
int ncols, int nrows, int n_top_k) {
int ne00, int ncols, int nrows, int n_top_k) {
GGML_ASSERT(n_top_k <= ncols);
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int);
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, bias, src_p, dst, ncols, ncols_pad, n_top_k);
k_topk_sum<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, ctx.stream()>>>(src, bias, src_p, dst, ne00, ncols, ncols_pad, n_top_k);
}
void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -453,7 +507,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds
CUDA_CHECK(cudaGetLastError());
#else
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), n_per_group, nrows*n_groups, nk);
ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), ne00, n_per_group, nrows*n_groups, nk);
CUDA_CHECK(cudaGetLastError());
#endif
@@ -495,26 +549,19 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds
int n_per_group = ne00/n_groups;
GGML_ASSERT(nk <= n_per_group);
GGML_ASSERT(n_top_groups <= n_groups);
int n_discarded_groups = n_groups - n_top_groups;
ggml_cuda_pool_alloc<float> group_scores(ctx.pool(), nrows*n_groups);
ggml_cuda_op_topk_sum(ctx, (const float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(),
n_per_group, nrows*n_groups, nk);
ne00, n_per_group, nrows*n_groups, nk);
CUDA_CHECK(cudaGetLastError());
ggml_cuda_pool_alloc<int> discarded_groups(ctx.pool(), nrows*n_discarded_groups);
argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream());
ggml_cuda_pool_alloc<uint8_t> group_mask(ctx.pool(), nrows*n_groups);
argsort_f32_u8_cuda(group_scores.get(), group_mask.get(), n_groups, nrows, n_top_groups, GGML_SORT_ORDER_DESC, ctx.stream());
CUDA_CHECK(cudaGetLastError());
{
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_apply_mask<<<block_nums, block_dims, 0, ctx.stream()>>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00);
CUDA_CHECK(cudaGetLastError());
}
argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream());
argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, group_mask.get(),
(float *)dst->data, (int *)topk->data,
ne00, nrows, ne0, topk->nb[1], n_per_group, GGML_SORT_ORDER_DESC, ctx.stream());
}

View File

@@ -492,6 +492,41 @@ static __global__ void fused_add_rms_norm_f32(const float * a, const float * b,
}
}
template <int block_size>
static __global__ void fused_add_add_rms_norm_f32(const float * a1, const float * a2, const float * b, const float * c,
float * dst_add, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = a1[row*ncols + col] + a2[row*ncols + col] + b[row*ncols + col];
tmp += xi * xi;
dst_add[row*ncols + col] = xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * c[col] * dst_add[row*ncols + col];
}
}
static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
@@ -538,3 +573,49 @@ void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tenso
src1_d, (float *)add->data, dst_d, ne00, nrows, eps, stream);
}
static void fused_add_add_rms_norm_f32_cuda(const float * a1, const float * a2, const float * b, const float * c, float * dst_add, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
fused_add_add_rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(a1, a2, b, c, dst_add, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
fused_add_add_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(a1, a2, b, c, dst_add, dst, ncols, eps);
}
}
void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx,
ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
//const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(add1->data == add2->src[0]->data);
GGML_ASSERT(add2->data == src0->data);
GGML_ASSERT(ggml_is_contiguous(src0));
//GGML_ASSERT(ggml_is_contiguous(add->src[0]));
//GGML_ASSERT(ggml_is_contiguous(add->src[1]));
//GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1]));
//GGML_ASSERT(ggml_are_same_shape(add->src[0], src0));
//GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32);
//GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
GGML_ASSERT(ggml_nrows(src1) == 1);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data,
src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream);
}

View File

@@ -9,3 +9,5 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);
void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst);

View File

@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
}
}
static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, float s, float b) {
const int row = blockIdx.x;
const int col = threadIdx.x;
@@ -27,6 +27,8 @@ static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float *
sum = warp_reduce_sum(sum);
//sum = s*sum + b;
float norm = sum > 0 ? 1/sum : 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
dst[row * ncols + i] = x[row * ncols + i] * norm;
@@ -42,10 +44,10 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, float s, float b, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_div_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
k_sum_rows_div_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, s, b);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -66,7 +68,16 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}
void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float s = 1, b = 0;
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(dst->src[1]->op == GGML_OP_SUM_ROWS || dst->src[1]->op == GGML_OP_SCALE);
if (dst->src[1]->op == GGML_OP_SCALE) {
GGML_ASSERT(dst->src[1]->src[0]->op == GGML_OP_SUM_ROWS);
auto params = (const float *)dst->src[1]->op_params;
s = params[0];
b = params[1];
}
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
@@ -78,5 +89,5 @@ void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * ds
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, s, b, stream);
}

View File

@@ -125,6 +125,16 @@ static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
dst[i] = 1.0f / (1.0f + expf(-x[i]));
}
static __global__ void biased_sigmoid_f32(const float * x, const float * bias, float * dst, float * dst_biased, const int k, const int ncols) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = 1.0f / (1.0f + expf(-x[i]));
dst_biased[i] = dst[i] + bias[i % ncols];
}
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -221,6 +231,11 @@ static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStre
sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void biased_sigmoid_f32_cuda(const float * x, const float * bias, float * dst, float * dst_biased, const int k, const int ncols, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
biased_sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, bias, dst, dst_biased, k, ncols);
}
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -402,6 +417,26 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_biased_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->op == GGML_OP_ADD);
GGML_ASSERT(dst->src[0]->op == GGML_OP_UNARY);
const ggml_tensor * src0 = dst->src[0]->src[0];
const ggml_tensor * bias = dst->src[1];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(bias->ne[0] == src0->ne[0]);
GGML_ASSERT(ggml_nrows(bias) == 1);
biased_sigmoid_f32_cuda(src0_d, (const float *)bias->data, (float *)dst->src[0]->data, dst_d, ggml_nelements(src0), src0->ne[0], stream);
}
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@@ -47,6 +47,8 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_biased_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -7793,6 +7793,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
cur = ffn_inp;
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
ggml_build_forward_expand(gf, cur);
cb(cur, "attn_post_norm", il);
bool use_dup_bias = cur->ne[1] < 32 && model.layers[il].ffn_up_exps_b_dup &&