diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index d5db3199ca..bd9c358b57 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -140,14 +140,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 32, 128, 128, + 128, 128, 128, 16, 16, 32, 32, - 1, 1, + 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, FP8>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 04f10b53ae..12a6f4bcb2 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -185,8 +185,8 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t sorted_tile_num = 18; + ck::index_t valid_tile_num = 16; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -237,10 +237,10 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts); } if(tokens * topk > valid_size) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index a2d1114bbe..fba46d4ac6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -186,7 +186,9 @@ struct GridwiseMoeGemm math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; @@ -246,7 +248,7 @@ struct GridwiseMoeGemm } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPack / KGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -388,7 +390,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1297,7 +1299,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1804,7 +1806,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds