From 66efcf96038fb2a7d31bfdec4cdf35460962fc0c Mon Sep 17 00:00:00 2001 From: letaoqin Date: Wed, 27 Nov 2024 03:37:08 +0000 Subject: [PATCH] change g tile distribution --- .../fused_moegemm_pipeline_general.hpp | 3 +- .../fused_moegemm_pipeline_general_policy.hpp | 41 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index 15f34f2283..edc2bdca09 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -126,7 +126,7 @@ struct FusedMoeGemmPipeline_General Policy::template MakeGlobalTileDistribution_G()); // Block GEMM - constexpr auto gemm_0 = Policy::template GetBlockGemm0(); + constexpr auto gemm_0 = Policy::template GetBlockGemm0(); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); auto s_acc = SaccBlockTileType{}; @@ -138,7 +138,6 @@ struct FusedMoeGemmPipeline_General ignore = s_acc; store_tile(o_window_, a_dram_block); - #if 0 //check a matrix gather right or not constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans(); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index bb51fc6f28..d301f884af 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -17,7 +17,8 @@ namespace ck_tile { struct FusedMoeGemmPipelineGeneralPolicy { - static constexpr int kKIter = 2; + static constexpr int kKIter = 2; + static constexpr int kKPerBlock = 32; CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() { @@ -197,14 +198,18 @@ struct FusedMoeGemmPipelineGeneralPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() { - using S_ = typename Problem::BlockShape; + using S_ = typename Problem::BlockShape; + constexpr index_t K2 = S_::Warp_K0; + constexpr index_t K1 = get_warp_size() / S_::Warp_N0; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + return make_static_tile_distribution( tile_distribution_encoding< sequence<1>, tuple, - sequence>, - tuple, sequence<1, 2>>, + sequence>, tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, sequence<1, 2, 2>, sequence<0, 0, 2>>{}); } @@ -212,23 +217,21 @@ struct FusedMoeGemmPipelineGeneralPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0() { - using S_ = typename Problem::BlockShape; - using GemmProblem = - BlockGemmProblem>; + using S_ = typename Problem::BlockShape; + using GemmProblem = BlockGemmProblem>; constexpr auto warp_gemm = GetWarpGemm0(); - using BlockGemmPolicy = - BlockGemmASmemBRegCRegV1CustomPolicy; + using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy; return BlockGemmASmemBRegCRegV1{}; }