diff --git a/example/ck_tile/16_fused_moe_general/main.cpp b/example/ck_tile/16_fused_moe_general/main.cpp index 0b36090686..b08cff5758 100644 --- a/example/ck_tile/16_fused_moe_general/main.cpp +++ b/example/ck_tile/16_fused_moe_general/main.cpp @@ -59,6 +59,21 @@ auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, return t; } +template +void output_matrix_2d(ck_tile::HostTensor& data, int m,int n) +{ + std::cout << std::endl; + for(int i = 0; i < m; i++) + { + std::cout << "Line " << i << "\t"; + for(int j = 0; j < n; j++) + { + std::cout << ck_tile::type_convert(data(i,j)) << "\t"; + } + std::cout << std::endl; + } +} + template void topid_unique_gen( std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) @@ -256,6 +271,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // } // std::cout << std::endl; // } + output_matrix_2d(a_host, tokens, hidden_size); // std::cout << sorted_token_ids_host << std::endl; // std::cout << num_sorted_tiles_host << std::endl; // std::cout << sorted_expert_ids_host << std::endl; @@ -277,6 +293,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); + o_buf.SetZero(); fused_moegemm_traits traits{prec_i, prec_w, @@ -363,6 +380,8 @@ bool run(const ck_tile::ArgParser& arg_parser) pass &= ck_tile::check_err( o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + + output_matrix_2d(o_dev, tokens, hidden_size); } std::cout << std::flush << std::endl; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index 498440c6d2..b0571f23e5 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -70,15 +70,15 @@ struct FusedMoeGemmPipeline_General CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // // matrix a or tokens smem - // constexpr index_t smem_mat_a = - // BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); - // // shuffle C matrix - // constexpr index_t smem_bridge = - // BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); + // matrix a or tokens smem + constexpr index_t smem_mat_a = + BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); + // shuffle C matrix + constexpr index_t smem_bridge = + BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); - // return max(smem_mat_a, smem_bridge); - return Policy::template GetSmemSize(); + return max(smem_mat_a, smem_bridge); + //return Policy::template GetSmemSize(); } // this is the thread-offset along row/col @@ -105,35 +105,46 @@ struct FusedMoeGemmPipeline_General ignore = hidden_size; ignore = intermediate_size; - // CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); - // auto a_lds_view = make_tensor_view( - // smem_0, Policy::template MakeLdsStoreDesc_A()); - // auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number{}, number{}), {0, 0}); + CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); + auto a_lds_view = make_tensor_view( + smem_0, Policy::template MakeLdsStoreDesc_A()); + auto a_lds_win = make_tile_window(a_lds_view, make_tuple(number{}, number{}), {0, 0}); auto a_global_to_dram_window = make_tile_window( a_window_.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_window_.get_window_origin(), Policy::template MakeGlobalTileDistribution_A()); + + // auto o_win = make_tile_window_linear( + // o_window_, Policy::template MakeGlobalTileDistribution_O()); + + auto a_dram_block = load_tile(a_global_to_dram_window); - //store_tile(a_lds_win, a_dram_block); - ignore = a_dram_block; + store_tile(a_lds_win, a_dram_block); + store_tile(o_window_, a_dram_block); + #if 0 //check a matrix gather right or not - constexpr auto a_spans = decltype(a_dram)::get_distributed_spans(); - int counter = 0; + constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); + int counter = 0; sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) { - sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){ - constexpr auto i_j_idx = make_tuple(idxm, idxk); - if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0){ - counter = counter + 1; - index_t idm_0 = idxm.impl_.at(0); - index_t idn_0 = idxk.impl_.at(0); - printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n", idm_0, idn_0, counter, ck_tile::type_convert(a_dram(i_j_idx))); - } - }); + sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) { + constexpr auto i_j_idx = make_tuple(idxm, idxk); + if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) + { + counter = counter + 1; + index_t idm_0 = idxm.impl_.at(0); + index_t idn_0 = idxk.impl_.at(0); + printf("in A idm is %d , idn_ is %d , counter is %d, value is: %f \n", + idm_0, + idn_0, + counter, + ck_tile::type_convert(a_dram_block(i_j_idx))); + } }); + }); #endif } }; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index 09399d1975..e6ea78def7 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -232,53 +232,6 @@ struct FusedMoeGemmPipelineGeneralPolicy } } -#if 0 - // Caution: this will require global memory pre-shuffled to follow the mfma layout - template - CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled() - { - static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0); - - if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) - { - constexpr index_t Kv = Alignment; - constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; - - static_assert(KPerBlock % (K1 * K2) == 0); - constexpr index_t Nr = NPerBlock / Nw; - constexpr index_t Kr = KPerBlock / (Kv * Kw); - - constexpr index_t Nr_p = WavesPerBlock_N; - constexpr index_t Kr_p = WavesPerBlock_K; - constexpr index_t Nr_y = Nr / Nr_p; - constexpr index_t Kr_y = Kr / Kr_p; - - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<1>, // 0 - // major 1 2 3 - // minor 0 1 0 1 0 1 2 - tuple, sequence, sequence>, - - // Nr_p, Kr_p Kw Nw - tuple, sequence<3, 3>>, - tuple, sequence<0, 1>>, - - // Nr_y Kr_y Kv - sequence<1, 2, 3>, - sequence<0, 0, 2>>{}); - // clang-format on - } - } -#endif template {}), - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + // make_pass_through_transform(), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return lds_block_desc_issues_warps_lanes; } @@ -446,12 +399,13 @@ struct FusedMoeGemmPipelineGeneralPolicy constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( lds_block_desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_merge_transform(make_tuple( - number{}, number{}, number{}))), - make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + make_tuple( + //make_pass_through_transform(number{}), + //make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{},number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return lds_block_desc_issues_warps_lanes; }