From f9c29b5ec796efe04840d583a2d807bfc185676f Mon Sep 17 00:00:00 2001 From: coderfeli Date: Fri, 25 Apr 2025 03:09:53 +0000 Subject: [PATCH] set 16x16 --- .../moe_gemm1_xdl_fp8.cpp | 23 ++++++++-------- .../gpu/grid/gridwise_moe_gemm.hpp | 26 +++++++++---------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index f594080755..ec33dbc24f 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -155,13 +155,13 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = true; +static constexpr ck::index_t Nswizzle = false; static constexpr bool MulRoutedWeight = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); @@ -188,7 +188,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, + 2, 2, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; // clang-format on @@ -201,11 +201,11 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; - ck::index_t tokens = 128; + ck::index_t sorted_tile_num = 133; + ck::index_t valid_tile_num = 128; + ck::index_t tokens = 8192; ck::index_t topk = 2; // ck::index_t tokens = batch * topk; @@ -268,11 +268,10 @@ int main(int argc, char* argv[]) // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / (valid_tile_num / experts); } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index fba46d4ac6..7b399c6daa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1681,7 +1681,8 @@ struct GridwiseMoeGemm const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; @@ -1690,12 +1691,13 @@ struct GridwiseMoeGemm const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { - const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; - const index_t prefix_block = ecnt_prefix * problem.NBlock; - const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; - const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; - const index_t bid_new = blockIdx.x - prefix_block; - const index_t nid = __builtin_amdgcn_readfirstlane( + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); @@ -1708,7 +1710,6 @@ struct GridwiseMoeGemm }(); const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; - const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); @@ -1720,11 +1721,9 @@ struct GridwiseMoeGemm constexpr auto AMRepeats = MPerBlock / AMThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || - token0 >= problem.NumTokens) + if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray - gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -2083,8 +2082,7 @@ struct GridwiseMoeGemm const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_offsets; StaticallyIndexedArray scatter_weights; //= for topk auto dstidx = sfc_cde_block.GetIndex(access_id);