diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 563433ec..29f9d26c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3094,12 +3094,28 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } +static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, int last) { + if (last <= first) return true; + int device = ((const ggml_backend_cuda_buffer_context *)cgraph->nodes[first]->buffer->context)->device; + for (int i = first; i <= last; ++i) { + auto node = cgraph->nodes[i]; + if (((const ggml_backend_cuda_buffer_context *)node->buffer->context)->device != device) return false; + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (!node->src[j] || !node->src[j]->buffer) continue; + if (((const ggml_backend_cuda_buffer_context *)node->src[j]->buffer->context)->device != device) return false; + } + } + return true; +} + static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) { // why is this here instead of mul_mat? if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); } +#define ENABLE_FUSION true + #if IK_PRINT_TIMING int64_t tim1 = ggml_time_us(); #endif @@ -3129,17 +3145,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_dup(ctx, dst); break; case GGML_OP_ADD: - if (i + 1 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_ADD && + cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && + ggml_is_contiguous(dst->src[0]) && + ggml_is_contiguous(dst->src[1]) && + ggml_are_same_shape(dst->src[0], dst->src[1]) && + dst == cgraph->nodes[i+1]->src[0] && + ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) && + ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) && + cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] && + ops_are_same_device(cgraph, i, i+2)) { + //printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name); + ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + } + else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && - ggml_are_same_shape(dst->src[0], dst->src[1])) { + ggml_are_same_shape(dst->src[0], dst->src[1]) && + dst == cgraph->nodes[i+1]->src[0] && ops_are_same_device(cgraph, i, i+1)) { ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]); ++i; } else { ggml_cuda_op_add(ctx, dst); } - //ggml_cuda_op_add(ctx, dst); break; case GGML_OP_ADD_ID: ggml_cuda_op_add_id(ctx, dst); @@ -3183,22 +3214,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_relu(ctx, dst); break; case GGML_UNARY_OP_SIGMOID: - if (i + 5 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && cgraph->nodes[i+4]->op == GGML_OP_VIEW && - cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) { + cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+5)) { cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); i += 5; } - else if (i + 4 < cgraph->n_nodes && + else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && - cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { + cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && ops_are_same_device(cgraph, i, i+4)) { cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]); i += 4; + } else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && ops_are_same_device(cgraph, i, i+2)) { + ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]); + i += 2; } else { ggml_cuda_op_sigmoid(ctx, dst); } @@ -3309,12 +3345,13 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - if (i + 4 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ARGSORT && cgraph->nodes[i+3]->op == GGML_OP_VIEW && cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS && - ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4])) { + ggml_cuda_should_use_topk_moe(cgraph->nodes[i], cgraph->nodes[i+4]) && + ops_are_same_device(cgraph, i, i+4)) { ggml_cuda_op_topk_moe(ctx, cgraph->nodes[i], cgraph->nodes[i+4], cgraph->nodes[i+3]); i += 4; } else { @@ -3343,10 +3380,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_pool2d(ctx, dst); break; case GGML_OP_SUM_ROWS: - if (i + 1 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_SCALE && + cgraph->nodes[i+2]->op == GGML_OP_DIV && + cgraph->nodes[i+1]->src[0] == dst && + cgraph->nodes[i+2]->src[1] == cgraph->nodes[i+1] && + cgraph->nodes[i+2]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+2)) { + ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]); + i += 2; + } + else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_DIV && cgraph->nodes[i+1]->src[1] == dst && - cgraph->nodes[i+1]->src[0] == dst->src[0]) { + cgraph->nodes[i+1]->src[0] == dst->src[0] && ops_are_same_device(cgraph, i, i+1)) { ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]); ++i; } else { @@ -3354,12 +3400,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } break; case GGML_OP_ARGSORT: - if (i + 5 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX && - cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) { + cgraph->nodes[i+5]->op == GGML_OP_RESHAPE && ops_are_same_device(cgraph, i, i+4)) { cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]); i += 5; } else { @@ -3390,6 +3436,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1)); #endif +#undef ENABLE_FUSION + return true; } diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 14d5d93b..ca11ad13 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -59,7 +59,7 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n // int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -97,17 +97,17 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n } template -static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop, - size_t nb_ids) { +static __global__ void k_argsort_f32_u8(const float * x, uint8_t * dst, const int ncols, int ncols_pad, int ntop) { +// int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; } - const float * x_row = x_biased + row * ncols; + const float * x_row = x + row * ncols; extern __shared__ int dst_row[]; // initialize indices @@ -117,6 +117,32 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float sort(ncols_pad, ncols, col, x_row, dst_row); + if (col < ncols) dst[row*ncols + dst_row[col]] = col < ntop ? 1 : 0; +} + +template +static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, const uint8_t * group_mask, + float * weights, int * ids, const int ncols, int ncols_pad, int ntop, size_t nb_ids, int n_per_group, int n_groups) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.x; + + if (col >= ncols_pad) { + return; + } + + extern __shared__ int dst_row[]; + auto x_row = (float *)(dst_row + ncols_pad); + + // initialize indices + dst_row[col] = col; + int ig = col / n_per_group; + x_row[col] = ig < n_groups && group_mask[row*n_groups + ig] ? x_biased[row * ncols + col] : -INFINITY; + + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); + if (col < ntop) { weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]])); auto row_ids = (int *)((char *)ids + row*nb_ids); @@ -129,7 +155,7 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float size_t nb_ids) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -158,7 +184,7 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in size_t nb_ids) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -201,10 +227,11 @@ static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, in } template -static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { +static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, + const int ne00, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort int col = threadIdx.x; - int row = blockIdx.y; + int row = blockIdx.x; if (col >= ncols_pad) { return; @@ -218,7 +245,7 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x if (bias && x_p) { float * x_p_row = x_p + row * ncols; if (col < ncols) { - x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[col]; + x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[(row * ncols + col)%ne00]; } x_row = x_p_row; } @@ -227,6 +254,10 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x sort(ncols_pad, ncols, col, x_row, dst_row); + if (n_top_k == 2) { + float val = x_row[dst_row[0]] + x_row[dst_row[1]]; + if (col == 0) dst[row] = val; + } else { float val = col < n_top_k ? x_row[dst_row[col]] : 0; val = warp_reduce_sum(val); if (blockDim.x > WARP_SIZE) { @@ -248,6 +279,7 @@ static __global__ void k_topk_sum(const float * x, const float * bias, float * x if (col == 0) { dst[row] = val; } + } } static __global__ void k_apply_mask(float * dst, const int * groups, @@ -275,7 +307,7 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -300,24 +332,46 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co } } -static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop, - size_t nb_ids, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_u8_cuda(const float * x, uint8_t * dst, const int ncols, const int nrows, int ntop, + ggml_sort_order order, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_f32_i32<<>>(x_biased, x, weights, ids, - ncols, ncols_pad, ntop, nb_ids); + k_argsort_f32_u8<<>>(x, dst, ncols, ncols_pad, ntop); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_f32_i32<<>>(x_biased, x, weights, ids, - ncols, ncols_pad, ntop, nb_ids); + k_argsort_f32_u8<<>>(x, dst, ncols, ncols_pad, ntop); + } else { + GGML_ABORT("fatal error"); + } +} + +static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, const uint8_t * group_mask, + float * weights, int * ids, const int ncols, const int nrows, int ntop, + size_t nb_ids, int n_per_group, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(nrows, 1, 1); + const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float)); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_f32_f32_i32<<>>(x_biased, x, group_mask, weights, ids, + ncols, ncols_pad, ntop, nb_ids, n_per_group, ncols/n_per_group); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_f32_f32_i32<<>>(x_biased, x, group_mask, weights, ids, + ncols, ncols_pad, ntop, nb_ids, n_per_group, ncols/n_per_group); } else { GGML_ABORT("fatal error"); } @@ -329,7 +383,7 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float)); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -352,7 +406,7 @@ static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, in const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int); // FIXME: this limit could be raised by ~2-4x on Ampere or newer @@ -408,18 +462,18 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * } static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, const float * bias, float * src_p, float * dst, - int ncols, int nrows, int n_top_k) { + int ne00, int ncols, int nrows, int n_top_k) { GGML_ASSERT(n_top_k <= ncols); const int ncols_pad = next_power_of_2(ncols); const dim3 block_dims(ncols_pad, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int); GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); - k_topk_sum<<>>(src, bias, src_p, dst, ncols, ncols_pad, n_top_k); + k_topk_sum<<>>(src, bias, src_p, dst, ne00, ncols, ncols_pad, n_top_k); } void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -453,7 +507,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds CUDA_CHECK(cudaGetLastError()); #else ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); - ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), n_per_group, nrows*n_groups, nk); + ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), ne00, n_per_group, nrows*n_groups, nk); CUDA_CHECK(cudaGetLastError()); #endif @@ -495,26 +549,19 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds int n_per_group = ne00/n_groups; GGML_ASSERT(nk <= n_per_group); GGML_ASSERT(n_top_groups <= n_groups); - int n_discarded_groups = n_groups - n_top_groups; ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); ggml_cuda_op_topk_sum(ctx, (const float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(), - n_per_group, nrows*n_groups, nk); + ne00, n_per_group, nrows*n_groups, nk); CUDA_CHECK(cudaGetLastError()); - ggml_cuda_pool_alloc discarded_groups(ctx.pool(), nrows*n_discarded_groups); - argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream()); + ggml_cuda_pool_alloc group_mask(ctx.pool(), nrows*n_groups); + argsort_f32_u8_cuda(group_scores.get(), group_mask.get(), n_groups, nrows, n_top_groups, GGML_SORT_ORDER_DESC, ctx.stream()); CUDA_CHECK(cudaGetLastError()); - { - const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums(nrows, 1, 1); - k_apply_mask<<>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00); - CUDA_CHECK(cudaGetLastError()); - } - - argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data, - ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, group_mask.get(), + (float *)dst->data, (int *)topk->data, + ne00, nrows, ne0, topk->nb[1], n_per_group, GGML_SORT_ORDER_DESC, ctx.stream()); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 5a49132a..f296b79f 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -492,6 +492,41 @@ static __global__ void fused_add_rms_norm_f32(const float * a, const float * b, } } +template +static __global__ void fused_add_add_rms_norm_f32(const float * a1, const float * a2, const float * b, const float * c, + float * dst_add, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = a1[row*ncols + col] + a2[row*ncols + col] + b[row*ncols + col]; + tmp += xi * xi; + dst_add[row*ncols + col] = xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = scale * c[col] * dst_add[row*ncols + col]; + } +} static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { @@ -538,3 +573,49 @@ void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tenso src1_d, (float *)add->data, dst_d, ne00, nrows, eps, stream); } +static void fused_add_add_rms_norm_f32_cuda(const float * a1, const float * a2, const float * b, const float * c, float * dst_add, float * dst, + const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + fused_add_add_rms_norm_f32<256><<>>(a1, a2, b, c, dst_add, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + fused_add_add_rms_norm_f32<1024><<>>(a1, a2, b, c, dst_add, dst, ncols, eps); + } +} + +void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, + ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + //const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(add1->data == add2->src[0]->data); + GGML_ASSERT(add2->data == src0->data); + GGML_ASSERT(ggml_is_contiguous(src0)); + //GGML_ASSERT(ggml_is_contiguous(add->src[0])); + //GGML_ASSERT(ggml_is_contiguous(add->src[1])); + //GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1])); + //GGML_ASSERT(ggml_are_same_shape(add->src[0], src0)); + //GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32); + //GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + + const int64_t nrows = ggml_nrows(src0); + fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data, + src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 29d67d2e..40f758de 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -9,3 +9,5 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst); + +void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4888bbfd..c86ad4ed 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } -static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, float s, float b) { const int row = blockIdx.x; const int col = threadIdx.x; @@ -27,6 +27,8 @@ static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * sum = warp_reduce_sum(sum); + //sum = s*sum + b; + float norm = sum > 0 ? 1/sum : 0.0f; for (int i = col; i < ncols; i += blockDim.x) { dst[row * ncols + i] = x[row * ncols + i] * norm; @@ -42,10 +44,10 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int k_sum_rows_f32<<>>(x, dst, ncols); } -static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, float s, float b, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); - k_sum_rows_div_f32<<>>(x, dst, ncols); + k_sum_rows_div_f32<<>>(x, dst, ncols, s, b); } void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -66,7 +68,16 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + float s = 1, b = 0; const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(dst->src[1]->op == GGML_OP_SUM_ROWS || dst->src[1]->op == GGML_OP_SCALE); + if (dst->src[1]->op == GGML_OP_SCALE) { + GGML_ASSERT(dst->src[1]->src[0]->op == GGML_OP_SUM_ROWS); + auto params = (const float *)dst->src[1]->op_params; + s = params[0]; + b = params[1]; + } + const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); @@ -78,5 +89,5 @@ void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * ds const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, stream); + sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, s, b, stream); } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index abcd8be4..090f5e86 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -125,6 +125,16 @@ static __global__ void sigmoid_f32(const float * x, float * dst, const int k) { dst[i] = 1.0f / (1.0f + expf(-x[i])); } +static __global__ void biased_sigmoid_f32(const float * x, const float * bias, float * dst, float * dst_biased, const int k, const int ncols) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = 1.0f / (1.0f + expf(-x[i])); + dst_biased[i] = dst[i] + bias[i % ncols]; +} + static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -221,6 +231,11 @@ static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStre sigmoid_f32<<>>(x, dst, k); } +static void biased_sigmoid_f32_cuda(const float * x, const float * bias, float * dst, float * dst_biased, const int k, const int ncols, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE; + biased_sigmoid_f32<<>>(x, bias, dst, dst_biased, k, ncols); +} + static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; hardsigmoid_f32<<>>(x, dst, k); @@ -402,6 +417,26 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_biased_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(dst->op == GGML_OP_ADD); + GGML_ASSERT(dst->src[0]->op == GGML_OP_UNARY); + const ggml_tensor * src0 = dst->src[0]->src[0]; + const ggml_tensor * bias = dst->src[1]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(bias->ne[0] == src0->ne[0]); + GGML_ASSERT(ggml_nrows(bias) == 1); + + biased_sigmoid_f32_cuda(src0_d, (const float *)bias->data, (float *)dst->src[0]->data, dst_d, ggml_nelements(src0), src0->ne[0], stream); +} + void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index d5481a60..f47a5cc7 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -47,6 +47,8 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_biased_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 0ccdcc6e..3e55cc59 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -7793,6 +7793,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() { cur = ffn_inp; cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); + ggml_build_forward_expand(gf, cur); cb(cur, "attn_post_norm", il); bool use_dup_bias = cur->ne[1] < 32 && model.layers[il].ffn_up_exps_b_dup &&