diff --git a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp b/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp index 9b157d4f72..e573fb131b 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp @@ -133,13 +133,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, //threadnum, mblock, nblock, kblock - 256, 32, 128, 128, + 256, 64, 128, 128, // ak1, bk1 8, 8, // mn_perxdl 32, 32, // mn_xdlperwave - 1, 1, + 2, 1, // a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, @@ -169,8 +169,8 @@ int main(int argc, char* argv[]) ck::index_t N = 6144; ck::index_t K = 8192; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t sorted_tile_size = 32; + ck::index_t sorted_tile_num = 1; + ck::index_t sorted_tile_size = 64; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t tokens = 64; @@ -368,7 +368,7 @@ int main(int argc, char* argv[]) auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_argument = ref_moe_gemm.MakeArgument( - sorted_token_ids, expert_ids, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); for(int m = 0; m < SORTED_SIZE; ++m) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index aca0df55b7..fa4ba36097 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -176,8 +176,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static_assert(NWave * warpSize == BlockSize); // static constexpr index_t NumTokens = 1; - static constexpr index_t Experts = 8; - static constexpr index_t SortedTileSize = 32; + static constexpr index_t SortedTileSize = MPerBlock; static constexpr auto MakeDsGridPointer() diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 103abc24a7..b328ad893f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -29,15 +29,17 @@ struct ReferenceMoeGemm : public device::BaseOperator struct Argument : public device::BaseArgument { Argument(const Tensor& sorted_token_ids, - const Tensor& expert_ids, - const Tensor& a_t_k, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) - : expert_ids_{expert_ids}, - sorted_token_ids_{sorted_token_ids}, + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + sorted_tile_size_{sorted_tile_size}, a_t_k_{a_t_k}, b_e_n_k_{b_e_n_k}, c_m_n_{c_m_n}, @@ -56,7 +58,7 @@ struct ReferenceMoeGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; - index_t sorted_tile_size = 32; + index_t sorted_tile_size_; }; // Invoker @@ -73,7 +75,7 @@ struct ReferenceMoeGemm : public device::BaseOperator ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; const int t = arg.sorted_token_ids_(m); - const int e = arg.expert_ids_(m / arg.sorted_tile_size); + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; if(t < token_cnt) { for(int k = 0; k < K; ++k) @@ -135,6 +137,7 @@ struct ReferenceMoeGemm : public device::BaseOperator static auto MakeArgument(const Tensor& sorted_token_ids, const Tensor& expert_ids, + const index_t sorted_tile_size, const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_m_n, @@ -142,7 +145,7 @@ struct ReferenceMoeGemm : public device::BaseOperator BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { - return Argument{sorted_token_ids, expert_ids, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op}; + return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; }