diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index efdee3d18c..57ed605a22 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -196,13 +196,22 @@ struct FusedMoeGemmPipelineGeneralPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O() { + using S_ = typename Problem::BlockShape; + constexpr int M_Thread_Num = 16; + constexpr int M_Rep = S_::Warp_M1 / M_Thread_Num; + static_assert(M_Rep <= 2); + + constexpr int N_Thread_Num = 4; + constexpr int NPerThread = S_::Warp_N1 / N_Thread_Num; + return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence<4, 8>>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); + tile_distribution_encoding< + sequence<4>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); } template