// Review notes for this patch fragment (kept ahead of the first `diff --git`
// header; the patch text below is unchanged).
//
// What the patch does, as visible in the hunks:
//  * moe_gemm2.cpp: the `expert_ids` host tensor is now sized by
//    `sorted_tile_num` instead of `experts`.
//  * CalculateGridSize: the launch grid becomes one-dimensional. The x extent
//    is ceil(N / NPerBlock) * ceil(M / MPerBlock); y and z are both 1. The
//    previous 2-D form is left behind as commented-out code.
//  * Kernel hunk: block_n_id / block_m_id are no longer blockIdx.x / blockIdx.y.
//    They are derived from blockIdx.x alone:
//      expert_block_id      = blockIdx.x / problem.NBlock
//      es                   = p_max_token_id[expert_block_id + 1]
//      expert_swizzle       = es > 0 ? es : 1
//      expert_block_swizzle = expert_block_id / expert_swizzle
//      b_block_id_swizzle   = blockIdx.x % (problem.NBlock * expert_swizzle)
//      block_n_id = b_block_id_swizzle % 8
//                   + b_block_id_swizzle / (8 * expert_swizzle) * 8
//      block_m_id = expert_block_swizzle * expert_swizzle
//                   + b_block_id_swizzle / 8 % expert_swizzle
//
// NOTE(review): the kernel reads p_max_token_id[expert_block_id + 1], but the
//   example's `max_token_id` tensor is built with HostTensorDescriptor({1}) and
//   only mData[0] is assigned. If that tensor backs p_max_token_id, the read is
//   past the end of the buffer -- confirm the buffer size/contents expected here.
// NOTE(review): the literal 8 in block_n_id / block_m_id looks like a fixed
//   N-strip width; the mapping presumably requires problem.NBlock to be a
//   multiple of 8 (otherwise block_n_id could reach or exceed NBlock) -- TODO
//   confirm, and consider naming the constant.
// NOTE(review): expert_id is loaded with index expert_block_id, whereas token0
//   is loaded with the swizzled block_m_id. These indices can differ when
//   expert_swizzle > 1 -- verify that every tile in a swizzle group belongs to
//   the same expert and shares the same `es` value.
// NOTE(review): expert_block_id is computed from blockIdx.x without
//   __builtin_amdgcn_readfirstlane, unlike the other indices derived from it --
//   confirm this is intentional.
// NOTE(review): the commented-out old 2-D indexing lines and the stray
//   `//p_max_token_id[expert_id + 1];` remnant could be dropped before merge.
diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index af67f8efd1..80a5484cdc 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -236,7 +236,7 @@ int main(int argc, char* argv[]) // const ck::index_t experts = 8; - Tensor expert_ids(HostTensorDescriptor({experts}, {1})); + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); max_token_id.mData[0] = valid_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp index cd4cfd3b18..63042ae505 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp @@ -197,8 +197,11 @@ struct GridwiseMoeGemmScatter __host__ static auto CalculateGridSize(index_t M, index_t N) { - return std::make_tuple(math::integer_divide_ceil(N, NPerBlock), - math::integer_divide_ceil(M, MPerBlock), + // return std::make_tuple(math::integer_divide_ceil(N, NPerBlock), + // math::integer_divide_ceil(M, MPerBlock), + // 1); + return std::make_tuple(math::integer_divide_ceil(N, NPerBlock) * math::integer_divide_ceil(M, MPerBlock), + 1, 1); } @@ -1149,10 +1152,22 @@ struct GridwiseMoeGemmScatter MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x); - const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); - const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = blockIdx.x / problem.NBlock; + // const index_t b_block_id = 
blockIdx.x % problem.NBlock; + const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]); + const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1]; + const index_t expert_block_swizzle = expert_block_id / expert_swizzle; + const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8); + const index_t block_m_id = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle); + + // const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + // const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); + // const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); + // const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);