From d55e98519f2b5ce8eb404b50f943a28916ac25a8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 30 Aug 2025 12:09:41 +0300 Subject: [PATCH] CUDA: prompt processing optimizations for MoE models (#739) * Skip the row id computation for the ffn_down op Sadly, almost negligible performance gain. * Also this doesn't do much * Also this barely moves the needle * This is slightly better --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 86 +++++++++++++--------------- ggml/src/ggml-cuda/mmq_id.cu | 61 +++++++++++++++++--- ggml/src/ggml-cuda/mmq_id.cuh | 7 ++- ggml/src/ggml-cuda/mmq_id_common.cuh | 18 +++++- src/llama.cpp | 2 +- 5 files changed, 117 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e734298c..f7e0b489 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2395,7 +2395,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { - ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr, false); return false; } @@ -2702,36 +2702,38 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream); - const int64_t ne11_flat = ne12*n_ids; - const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); - size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + - get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); - ggml_cuda_pool_alloc src1_quantized(ctx.pool(), nbytes_src1_q8_1); - - size_t ts_src1 = ggml_type_size(src1->type); - quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(), - src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, - ne10_padded, ne11_flat, 1, 1, stream); - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); - if (dst->src[4]) { - ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, - (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], - dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); - CUDA_CHECK(cudaGetLastError()); - } + { + const int64_t ne11_flat = ne12*n_ids; + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_quantized(ctx.pool(), nbytes_src1_q8_1); - dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); - if (dst->src[5]) { - ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data, - (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], - dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); - CUDA_CHECK(cudaGetLastError()); + size_t ts_src1 = ggml_type_size(src1->type); + quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(), + src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, + ne10_padded, ne11_flat, 1, 1, stream); + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false); + if (dst->src[5]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } } auto unary_op = (ggml_unary_op)dst->op_params[0]; @@ -2748,19 +2750,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { - //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr); - ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr); + ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr, true); + //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr); return true; } return false; } - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_1_row = *src0_1; ggml_tensor src0_2_row = *src0_2; ggml_tensor src1_row = *src1; @@ -2834,20 +2831,19 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor bool first = false; //true; - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); - std::vector moe_counts, cum_moe_counts; + const int64_t ne_get_rows = ne12 * n_ids; + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), ne_get_rows + (n_as + 2)/2); - bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); - if (is_ser) { - if (fuse_down) { - CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); - } else { - CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - } - } + compute_row_ids2((const int32_t *)ids->data, dev_row_mapping.get(), (int32_t *)(dev_row_mapping.get() + ne_get_rows), + ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream); + + std::vector cum_moe_counts(n_as + 1); + CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_row_mapping.get() + ne_get_rows, (n_as+1)*sizeof(int), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = moe_counts[i02]; + int64_t num_src1_rows = cum_moe_counts[i02+1] - cum_moe_counts[i02]; if (num_src1_rows == 0) continue; size_t mapping_offset = cum_moe_counts[i02]; diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 230715c0..7c578f44 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -24,11 +24,13 @@ struct mmq_ids_helper_store { }; static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); +struct mmid_row_mapping { int32_t i1, i2; }; + // Helper function for mul_mat_id, converts ids to a more convenient format. // ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. // ids_dst describes the same mapping but for the dst tensor. // The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. -template +template __launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) static __global__ void mmq_ids_helper( const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, @@ -103,8 +105,12 @@ static __global__ void mmq_ids_helper( const mmq_ids_helper_store store_it = store[itc]; const int it = store_it.it(); const int iex_used = store_it.iex_used(); - ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; - ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + if constexpr (type == 0) { + ids_src1[nex_prev + itc] = it; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } else { + ((mmid_row_mapping *)ids_src1)[nex_prev + itc] = {iex_used, it}; + } } if (threadIdx.x != 0) { @@ -120,7 +126,7 @@ static __global__ void mmq_ids_helper( expert_bounds[gridDim.x] = nex_prev + it_compact; } -template +template static void launch_mmq_ids_helper( const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { @@ -135,7 +141,7 @@ static void launch_mmq_ids_helper( const dim3 num_blocks(n_experts, 1, 1); const dim3 block_size(warp_size, 1, 1); const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); - mmq_ids_helper<<>> + mmq_ids_helper<<>> (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); } @@ -313,8 +319,48 @@ void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, CUDA_CHECK(cudaGetLastError()); } +void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, + cudaStream_t stream) { + + const int si1 = nb21 / sizeof(int); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); +} + void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, - const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) { + const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(ids_tensor->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. @@ -377,6 +423,7 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * ids_src1 = (int32_t *)ids_data; ids_dst = ids_src1 + ne_get_rows; expert_bounds = ids_dst + ne_get_rows; + if (is_next) ids_src1 = ids_dst; } else { GGML_ASSERT(ids_tensor->nb[0] == ggml_element_size(ids_tensor)); @@ -460,7 +507,7 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, ne02, ne02, s02, s12, s2, ne03, ne13, s03, s13, s3, - use_stream_k, ne12}; + use_stream_k, ne12, (int)n_expert_used, (int)ne02}; //printf("ne00 = %ld, ne01 = %ld, ne_get_rows = %ld, s01 = %ld, s1 = %ld\n", ne00, ne01, ne_get_rows, s01, s1); //printf("ne02 = %ld, s02 = %ld, s12 = %ld, s2 = %ld\n", ne02, s02, s12, s2); diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh index 56739307..4e73a4b9 100644 --- a/ggml/src/ggml-cuda/mmq_id.cuh +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -4,9 +4,14 @@ void ggml_cuda_mul_mat_q_id( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, - ggml_tensor * dst, char * ids_data, char * src1_quantized_data); + ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next); void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); + +struct mmid_row_mapping; +void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); + bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11); diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 89baa31b..9db10801 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -3943,7 +3943,7 @@ struct mmq_args_id { int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; - bool use_stream_k; int64_t ncols_max; + bool use_stream_k; int64_t ncols_max; int n_experts_used; int n_total_experts; }; template @@ -4057,9 +4057,21 @@ void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; const int nwarps = mmq_get_nwarps_host(cc, warp_size); - const int mmq_x_max = get_mmq_x_max_host(cc); + int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); + int ncols_max = args.ncols_max; + if (args.ids_dst && 4*args.n_experts_used < args.n_total_experts) { + ncols_max *= 4*args.n_experts_used; + ncols_max /= args.n_total_experts; + if (ncols_max < 1) ncols_max = 1; + ncols_max = 32*((ncols_max + 31)/32); + //ncols_max = 16*((ncols_max + 15)/16); + if (ncols_max > args.ncols_max) ncols_max = args.ncols_max; + //printf("%s: ncols_max = %d, %d\n", __func__, (int)args.ncols_max, ncols_max); + //mmq_x_max /= 2; + } + int mmq_x_best = 0; int ntiles_x_best = INT_MAX; @@ -4070,7 +4082,7 @@ void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args continue; } - const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; + const int ntiles_x = (ncols_max + mmq_x - 1) / mmq_x; if (ntiles_x < ntiles_x_best) { mmq_x_best = mmq_x; diff --git a/src/llama.cpp b/src/llama.cpp index 8d7d3917..3aabb1b0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13342,7 +13342,7 @@ struct llm_build_context { // whether to use n_tokens as the matrix dimension during multiplication or n_head // n_tokens is higher during prompt processing, this allows to optimize for this case - bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head; + bool pp_opt = n_tokens >= 32; //128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head; for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL;