From ef8e3620cc00e809cb6428d8f0d684dfc3b46fbd Mon Sep 17 00:00:00 2001 From: letaoqin Date: Mon, 25 Nov 2024 07:40:03 +0000 Subject: [PATCH] gather and scatter right --- example/ck_tile/16_fused_moe_general/main.cpp | 31 ++++++++++--------- .../core/algorithm/indexing_adaptor.hpp | 4 +-- .../fused_moegemm_pipeline_general.hpp | 15 +++++---- .../fused_moegemm_pipeline_general_policy.hpp | 14 +++++---- 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/example/ck_tile/16_fused_moe_general/main.cpp b/example/ck_tile/16_fused_moe_general/main.cpp index b08cff5758..20f637fb4a 100644 --- a/example/ck_tile/16_fused_moe_general/main.cpp +++ b/example/ck_tile/16_fused_moe_general/main.cpp @@ -60,15 +60,15 @@ auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, } template -void output_matrix_2d(ck_tile::HostTensor& data, int m,int n) +void output_matrix_2d(ck_tile::HostTensor& data, int m, int n) { - std::cout << std::endl; + std::cout << std::endl; for(int i = 0; i < m; i++) { std::cout << "Line " << i << "\t"; for(int j = 0; j < n; j++) { - std::cout << ck_tile::type_convert(data(i,j)) << "\t"; + std::cout << ck_tile::type_convert(data(i, j)) << "\t"; } std::cout << std::endl; } @@ -261,17 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.mData[0], experts, block_m); - // std::cout << std::endl; - // for(int i = 0; i < tokens; i++) - // { - // std::cout << "Line " << i << "\t"; - // for(int j = 0; j < hidden_size; j++) - // { - // std::cout << ck_tile::type_convert(a_host(i,j)) << "\t"; - // } - // std::cout << std::endl; - // } - output_matrix_2d(a_host, tokens, hidden_size); + + // output_matrix_2d(a_host, tokens, hidden_size); // std::cout << sorted_token_ids_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl; @@ -381,7 +372,17 @@ bool run(const ck_tile::ArgParser& arg_parser) o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; - output_matrix_2d(o_dev, tokens, hidden_size); + // std::cout << std::endl; + // int count = 0; + // for(int i = 0; i < tokens; i++) + // { + // std::cout << "Line " << i << "\t"; + // for(int j = 0; j < hidden_size; j++) + // { + // std::cout << ck_tile::type_convert(o_dev(count++)) << "\t"; + // } + // std::cout << std::endl; + // } } std::cout << std::flush << std::endl; diff --git a/include/ck_tile/core/algorithm/indexing_adaptor.hpp b/include/ck_tile/core/algorithm/indexing_adaptor.hpp index c6084b6a76..740ce728ef 100644 --- a/include/ck_tile/core/algorithm/indexing_adaptor.hpp +++ b/include/ck_tile/core/algorithm/indexing_adaptor.hpp @@ -80,7 +80,7 @@ struct indexing_adaptor pre_up_index_ = idx_up[number<0>{}]; pre_low_index_ = idx_low(number<0>{}); #if 0 - if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); } @@ -105,7 +105,7 @@ struct indexing_adaptor pre_up_index_ = up_index; pre_low_index_ = low_index; #if 0 - if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("\n index form %d to %d, diff from %d to %d \n", up_index, diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index b0571f23e5..bf97ff015c 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -78,7 +78,7 @@ struct FusedMoeGemmPipeline_General BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); return max(smem_mat_a, smem_bridge); - //return Policy::template GetSmemSize(); + // return Policy::template GetSmemSize(); } // this is the thread-offset along row/col @@ -108,7 +108,10 @@ struct FusedMoeGemmPipeline_General CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); auto a_lds_view = make_tensor_view( smem_0, Policy::template MakeLdsStoreDesc_A()); - auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_win = make_tile_window( + a_lds_view, + make_tuple(number{}, number{}), + {0, 0}); auto a_global_to_dram_window = make_tile_window( a_window_.get_bottom_tensor_view(), @@ -116,15 +119,11 @@ struct FusedMoeGemmPipeline_General a_window_.get_window_origin(), Policy::template MakeGlobalTileDistribution_A()); - // auto o_win = make_tile_window_linear( - // o_window_, Policy::template MakeGlobalTileDistribution_O()); - - auto a_dram_block = load_tile(a_global_to_dram_window); store_tile(a_lds_win, a_dram_block); store_tile(o_window_, a_dram_block); - + #if 0 //check a matrix gather right or not constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); @@ -132,7 +131,7 @@ struct FusedMoeGemmPipeline_General sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { constexpr auto i_j_idx = make_tuple(idxm, idxk); - if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) + if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { counter = counter + 1; index_t idm_0 = idxm.impl_.at(0); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index e6ea78def7..98a932a956 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -367,9 +367,10 @@ struct FusedMoeGemmPipelineGeneralPolicy constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( lds_block_desc_0, make_tuple( - // make_pass_through_transform(), + // make_pass_through_transform(), make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}, number{}))), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -400,10 +401,11 @@ struct FusedMoeGemmPipelineGeneralPolicy constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( lds_block_desc_0, make_tuple( - //make_pass_through_transform(number{}), - //make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{},number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), + // make_pass_through_transform(number{}), + // make_pass_through_transform(number{}), + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}));