From dc963eb3594ade13d88906eb7db00877429a3214 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 10 Dec 2025 06:14:21 +0000 Subject: [PATCH] Refactor policy --- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 7 +- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 5 - ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 261 +++++++----------- 3 files changed, 106 insertions(+), 167 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 1133da33ad..799f8f26a9 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -414,12 +414,7 @@ struct MXFlatmmKernel : FlatmmKernel(); - } - template CK_TILE_DEVICE auto operator()(Args&&... args) const { diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 4d76ab7da2..9a829a4bdf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,6 +7,8 @@ namespace ck_tile { +namespace detail { +template struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static constexpr auto I0 = number<0>{}; @@ -19,22 +21,44 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr int NXdlPack = 2; static constexpr int KXdlPack = 2; - template + private: + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + static_assert( + sizeof(ADataType) * numeric_traits::PackedSize == + sizeof(BDataType) * numeric_traits::PackedSize, + "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); + + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + + using TileShape = typename Problem::BlockGemmShape; + using BlockWarps = typename TileShape::BlockWarps; + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t WaveNum = BlockSize / WaveSize; + + static constexpr index_t MPerBlock = TileShape::kM; + static constexpr index_t NPerBlock = TileShape::kN; + static constexpr index_t KPerBlock = TileShape::kK; + static constexpr index_t MWarps = BlockWarps::at(I0); + static constexpr index_t NWarps = BlockWarps::at(I1); + static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size"); + + static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0); + static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1); + static_assert(MPerXdl == 16 && NPerXdl == 16); + static inline constexpr auto wg_attr_num_access = std::is_same_v, pk_fp4_t> ? WGAttrNumAccessEnum::Single : WGAttrNumAccessEnum::Double; - template + public: CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - static_assert( - sizeof(ADataType) * numeric_traits::PackedSize == - sizeof(BDataType) * numeric_traits::PackedSize, - "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher< // ADataType, @@ -46,7 +70,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy Problem::TransposeC, false, false, - wg_attr_num_access>; + wg_attr_num_access>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< // ADataType, BDataType, @@ -56,28 +80,20 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy return BlockFlatmmASmemBSmemCRegV1{}; } - template + template CK_TILE_DEVICE static constexpr auto MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view) { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - const auto& naive_desc = naive_view.get_tensor_descriptor(); constexpr auto ndims = remove_cvref_t::get_num_of_dimension(); static_assert(ndims == 2, "only support 2D tensor"); const auto rows = naive_desc.get_length(number<0>{}); const auto cols = naive_desc.get_length(number<1>{}); - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - const index_t K0 = cols / (K1 * K2); - const auto col_lens = make_tuple(K0, number{}, number{}); + constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 + const index_t K0 = cols / (K1 * K2); + const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds const index_t M0 = rows / M1; @@ -106,25 +122,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy TensorView::DstInMemOp>{naive_view.buf_, desc}; } - template CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { - - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 - constexpr index_t M2 = get_warp_size() / K1; // 8 - constexpr index_t M1 = BlockSize / get_warp_size(); // 4 + constexpr index_t M2 = WaveSize / K1; // 8 + constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); @@ -139,28 +144,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 0, 2>>{}); } - template CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor() { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - /*reduce transform layers,compare with old ck*/ - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); - constexpr index_t M3 = 4; // so that we can use imm offset to load lds - constexpr index_t M2 = get_warp_size() / K1 / M3; // 2 - constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 + constexpr index_t M3 = 4; // so that we can use imm offset to load lds + constexpr index_t M2 = WaveSize / K1 / M3; // 2 + constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16 static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!"); @@ -217,65 +211,45 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy return a_lds_block_desc; } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); - static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - - constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); - constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); - constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 + static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 constexpr int K_Lane = 64 / M_Lane; // 4 constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32 - constexpr index_t num_access_v = static_cast(wg_attr_num_access); + constexpr index_t num_access_v = static_cast(wg_attr_num_access); constexpr int K1 = K_Thread / num_access_v; // 16 return make_static_tile_distribution( std::conditional_t< num_access_v == 1, tile_distribution_encoding< - sequence, - tuple, sequence>, + sequence, + tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, sequence<2>, sequence<1>>, tile_distribution_encoding< // - sequence, - tuple, sequence>, + sequence, + tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<1, 2>>, sequence<2, 2>, sequence<0, 2>>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - using BDataType = remove_cvref_t; - constexpr index_t BPack = numeric_traits::PackedSize; - - static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t K1 = WaveSize; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t K0 = KWavePerBlk; - constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp - constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; constexpr index_t kKPerThread = 32; - constexpr index_t num_access_v = static_cast(wg_attr_num_access); + constexpr index_t num_access_v = static_cast(wg_attr_num_access); constexpr index_t K2 = kKPerThread / num_access_v; return make_static_tile_distribution( @@ -283,30 +257,26 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy num_access_v == 1, tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 1 64 32 + tuple, // 4 2 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>, tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>>{}); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) { - - using BDataType = remove_cvref_t; - constexpr auto BPackedSize = numeric_traits::PackedSize; - constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; @@ -314,7 +284,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static_assert(std::decay_t::get_num_of_dimension() == 2); auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( make_naive_tensor_descriptor_packed(make_tuple( flat_n, flat_k / flat_k_per_block, number{})), @@ -331,39 +301,25 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy byte_tensor_view, make_tuple(number{}, number{}), {origin_tmp[0], origin_tmp[1] / BPackedSize}, - MakeMX_BFlatBytesDramTileDistribution()); + MakeMX_BFlatBytesDramTileDistribution()); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); constexpr index_t K_Lanes = 64 / M_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = M_Lanes; - constexpr index_t Y1 = M_Warps; - constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); + constexpr index_t Y1 = MWarps; + constexpr index_t Y0 = MPerBlock / (MXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // repeat N_warps + tile_distribution_encoding, // repeat NWarps tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -371,36 +327,22 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t N_Lanes = TileShape::WarpTile::at(I1); constexpr index_t K_Lanes = 64 / N_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = N_Lanes; - constexpr index_t Y1 = N_Warps; - constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2); + constexpr index_t Y1 = NWarps; + constexpr index_t Y0 = NPerBlock / (NXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // ? + tile_distribution_encoding, // ? tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -408,20 +350,13 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0); - constexpr index_t M_Lane = TileShape::WarpTile::at(I0); - constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t MWavePerBlk = M_Warp; - + constexpr index_t K_Lane = 64 / MPerXdl; + constexpr index_t M_Lane = MPerXdl; return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -430,20 +365,13 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1); - constexpr index_t N_Lane = TileShape::WarpTile::at(I1); - constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t NWavePerBlk = N_Warp; - + constexpr index_t K_Lane = 64 / NPerXdl; + constexpr index_t N_Lane = NPerXdl; return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -452,20 +380,41 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - using ADataType = remove_cvref_t; - constexpr index_t APackedSize = numeric_traits::PackedSize; - return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / + return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / APackedSize; } - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return GetSmemSizeA(); + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } +}; +} // namespace detail + +struct MXFlatmmPipelineAgBgCrPolicy +{ + +#define FORWARD_METHOD_(method) \ + template \ + CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \ + { \ + return detail::MXFlatmmPipelineAgBgCrPolicy::method(std::forward(args)...); \ } + + FORWARD_METHOD_(GetBlockFlatmm); + FORWARD_METHOD_(MakeMX_AAsyncLoadDramDescriptor); + FORWARD_METHOD_(MakeMX_ADramTileDistribution); + FORWARD_METHOD_(MakeMX_ALdsBlockDescriptor); + FORWARD_METHOD_(MakeMX_ALDS_TileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramTileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramWindow); + FORWARD_METHOD_(MakeMX_ScaleA_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleA_FlatDramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_FlatDramTileDistribution); + FORWARD_METHOD_(GetSmemSizeA); + FORWARD_METHOD_(GetSmemSize); + +#undef FORWARD_METHOD_ }; } // namespace ck_tile