diff --git a/CHANGELOG.md b/CHANGELOG.md index 997fb8bb8c..a69ce2260e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. +* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline ### Changed @@ -36,6 +37,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE * Added the blockscale 2D support for CK_TILE GEMM. +* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types ### Changed diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 0134465347..d6c84f3064 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "32", "m dimension") - .insert("n", "128", "n dimension") + .insert("n", "512", "n dimension") .insert("k", "256", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") @@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[]) else throw std::runtime_error("Only support non-persistent kernel now!"); } + else if(mx_prec == "fp8xfp4") + { + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); + } + else if(mx_prec == "fp4xfp8") + { + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); + } else { throw std::runtime_error("Unsupported data_type!"); diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index e374a4ddd3..0b6185590f 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -76,6 +76,69 @@ struct MXfp8_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXf8f4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; +struct MXf4f8_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + template struct MXFlatmmPipelineProblem : FlatmmPipelineProblem= DsReadPreload) ? DsReadPreload @@ -470,11 +470,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(); - } - template CK_TILE_DEVICE auto operator()(Args&&... args) const { @@ -684,7 +679,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 a_warp_tensor; // preload A00,A10... from lds - s_waitcnt_barrier(); + s_waitcnt_barrier(); static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MXdlPack; constexpr auto kIter = loadIter / MXdlPack; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 4d76ab7da2..e188ddec61 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,6 +7,8 @@ namespace ck_tile { +namespace detail { +template struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static constexpr auto I0 = number<0>{}; @@ -14,27 +16,47 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr auto I2 = number<2>{}; static constexpr index_t kDramLoadPackBytes = 128; + static constexpr index_t DWORDx4 = 16; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; static constexpr int KXdlPack = 2; - template - static inline constexpr auto wg_attr_num_access = - std::is_same_v, pk_fp4_t> - ? WGAttrNumAccessEnum::Single - : WGAttrNumAccessEnum::Double; + private: + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + + using TileShape = typename Problem::BlockGemmShape; + using BlockWarps = typename TileShape::BlockWarps; + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t WaveNum = BlockSize / WaveSize; + + static constexpr index_t MPerBlock = TileShape::kM; + static constexpr index_t NPerBlock = TileShape::kN; + static constexpr index_t KPerBlock = TileShape::kK; + static constexpr index_t MWarps = BlockWarps::at(I0); + static constexpr index_t NWarps = BlockWarps::at(I1); + static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size"); + + static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0); + static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1); + static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2); + static_assert(MPerXdl == 16 && NPerXdl == 16); + static constexpr index_t K_Lane = get_warp_size() / 16; // 4 + static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 + + public: + static constexpr index_t AK1 = DWORDx4 * APackedSize; + static constexpr index_t BK1 = DWORDx4 * BPackedSize; - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - static_assert( - sizeof(ADataType) * numeric_traits::PackedSize == - sizeof(BDataType) * numeric_traits::PackedSize, - "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher< // ADataType, @@ -43,10 +65,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - Problem::TransposeC, - false, - false, - wg_attr_num_access>; + Problem::TransposeC>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< // ADataType, BDataType, @@ -56,28 +75,20 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy return BlockFlatmmASmemBSmemCRegV1{}; } - template + template CK_TILE_DEVICE static constexpr auto MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view) { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - const auto& naive_desc = naive_view.get_tensor_descriptor(); constexpr auto ndims = remove_cvref_t::get_num_of_dimension(); static_assert(ndims == 2, "only support 2D tensor"); const auto rows = naive_desc.get_length(number<0>{}); const auto cols = naive_desc.get_length(number<1>{}); - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - const index_t K0 = cols / (K1 * K2); - const auto col_lens = make_tuple(K0, number{}, number{}); + constexpr index_t K2 = AK1; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + const index_t K0 = cols / (K1 * K2); + const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds const index_t M0 = rows / M1; @@ -106,25 +117,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy TensorView::DstInMemOp>{naive_view.buf_, desc}; } - template CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { - - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K2 = AK1; // f4=32; f8=16 constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 - constexpr index_t M2 = get_warp_size() / K1; // 8 - constexpr index_t M1 = BlockSize / get_warp_size(); // 4 + constexpr index_t M2 = WaveSize / K1; // 8 + constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); @@ -139,28 +139,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 0, 2>>{}); } - template CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor() { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - - /*reduce transform layers,compare with old ck*/ - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + constexpr index_t K2 = AK1; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); - constexpr index_t M3 = 4; // so that we can use imm offset to load lds - constexpr index_t M2 = get_warp_size() / K1 / M3; // 2 - constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 + constexpr index_t M3 = 4; // so that we can use imm offset to load lds + constexpr index_t M2 = WaveSize / K1 / M3; // 2 + constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16 static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!"); @@ -168,14 +156,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // make_tuple(number{}, - number{}, number{}, + number{}, number{}, number{}, number{}, number{}), - make_tuple(number{}, - number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, @@ -187,8 +175,8 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(M0), - make_pass_through_transform(M1), make_pass_through_transform(K0), + make_pass_through_transform(M1), make_pass_through_transform(M2), make_xor_transform(make_tuple(number{}, number{})), make_pass_through_transform(number{})), @@ -210,103 +198,71 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy make_tuple(number{}, number{}, number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}), + make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}), make_tuple(sequence<0>{}, sequence<1>{})); // return a_lds_block_desc_permuted; return a_lds_block_desc; } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution() { - using TileShape = typename Problem::BlockGemmShape; + static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); - static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - - constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); - constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); - constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 - - constexpr int K_Lane = 64 / M_Lane; // 4 - - constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32 - constexpr index_t num_access_v = static_cast(wg_attr_num_access); - constexpr int K1 = K_Thread / num_access_v; // 16 - - return make_static_tile_distribution( - std::conditional_t< - num_access_v == 1, - tile_distribution_encoding< - sequence, - tuple, sequence>, + if constexpr(K_Thread == AK1) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, sequence<2>, - sequence<1>>, - tile_distribution_encoding< // - sequence, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 2>, - sequence<0, 2>>>{}); + sequence<1>>{}); + else + return make_static_tile_distribution(tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 2>, + sequence<0, 2>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - using BDataType = remove_cvref_t; - constexpr index_t BPack = numeric_traits::PackedSize; - - static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t K1 = WaveSize; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t K0 = KWavePerBlk; - constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - constexpr index_t kKPerThread = 32; - constexpr index_t num_access_v = static_cast(wg_attr_num_access); - constexpr index_t K2 = kKPerThread / num_access_v; - - return make_static_tile_distribution( - std::conditional_t< // - num_access_v == 1, + if constexpr(BK1 == K_Thread) + return make_static_tile_distribution( tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 1 64 32 + tuple, // 4 2 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, - sequence<2>>, + sequence<2>>{}); + else + return make_static_tile_distribution( tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, - sequence<0, 3>>>{}); + sequence<0, 3>>{}); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) { - - using BDataType = remove_cvref_t; - constexpr auto BPackedSize = numeric_traits::PackedSize; - constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; @@ -314,7 +270,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static_assert(std::decay_t::get_num_of_dimension() == 2); auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( make_naive_tensor_descriptor_packed(make_tuple( flat_n, flat_k / flat_k_per_block, number{})), @@ -331,39 +287,25 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy byte_tensor_view, make_tuple(number{}, number{}), {origin_tmp[0], origin_tmp[1] / BPackedSize}, - MakeMX_BFlatBytesDramTileDistribution()); + MakeMX_BFlatBytesDramTileDistribution()); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); constexpr index_t K_Lanes = 64 / M_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = M_Lanes; - constexpr index_t Y1 = M_Warps; - constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); + constexpr index_t Y1 = MWarps; + constexpr index_t Y0 = MPerBlock / (MXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // repeat N_warps + tile_distribution_encoding, // repeat NWarps tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -371,36 +313,22 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t N_Lanes = TileShape::WarpTile::at(I1); constexpr index_t K_Lanes = 64 / N_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = N_Lanes; - constexpr index_t Y1 = N_Warps; - constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2); + constexpr index_t Y1 = NWarps; + constexpr index_t Y0 = NPerBlock / (NXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // ? + tile_distribution_encoding, // ? tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -408,20 +336,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0); - constexpr index_t M_Lane = TileShape::WarpTile::at(I0); - constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t MWavePerBlk = M_Warp; - return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -430,20 +349,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1); - constexpr index_t N_Lane = TileShape::WarpTile::at(I1); - constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t NWavePerBlk = N_Warp; - return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -452,20 +362,41 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - using ADataType = remove_cvref_t; - constexpr index_t APackedSize = numeric_traits::PackedSize; - return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / + return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / APackedSize; } - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return GetSmemSizeA(); + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } +}; +} // namespace detail + +struct MXFlatmmPipelineAgBgCrPolicy +{ + +#define FORWARD_METHOD_(method) \ + template \ + CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \ + { \ + return detail::MXFlatmmPipelineAgBgCrPolicy::method(std::forward(args)...); \ } + + FORWARD_METHOD_(GetBlockFlatmm); + FORWARD_METHOD_(MakeMX_AAsyncLoadDramDescriptor); + FORWARD_METHOD_(MakeMX_ADramTileDistribution); + FORWARD_METHOD_(MakeMX_ALdsBlockDescriptor); + FORWARD_METHOD_(MakeMX_ALDS_TileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramTileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramWindow); + FORWARD_METHOD_(MakeMX_ScaleA_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleA_FlatDramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_FlatDramTileDistribution); + FORWARD_METHOD_(GetSmemSizeA); + FORWARD_METHOD_(GetSmemSize); + +#undef FORWARD_METHOD_ }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index a2c320f3e6..44a09423ee 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -306,10 +306,9 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -template -using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl< - WarpGemmAttributeMfma, - AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< // WarpGemmAttributeMfma, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9d928a7cfa..82c6e43834 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -116,15 +116,12 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template +struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed; }; -template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };