From 1246e65f50ff6056baf57165ec0f323f1d22e49f Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:20:41 -0700 Subject: [PATCH] [CK TILE] Refactor MX FLATMM example (#4821) Refactor the MX FLATMM example to support more pipelines across different architectures. This work facilitates the NPI team roadmap. --- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 106 +++--------- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 162 +----------------- .../mxgemm/mx_flatmm_arch_traits.hpp | 137 +++++++++++++++ .../18_flatmm/mxgemm/mx_flatmm_instance.cmake | 36 ++-- .../mxgemm/mx_flatmm_instance.cpp.in | 4 +- .../18_flatmm/mxgemm/mx_flatmm_instance.hpp | 10 +- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 17 +- 7 files changed, 199 insertions(+), 273 deletions(-) create mode 100644 example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 1141717545..3d51fd9907 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -20,7 +20,7 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template args = {a_dev_buf.GetDeviceBuffer(), b_shuffle_dev_buf.GetDeviceBuffer(), {}, @@ -99,7 +101,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, constexpr auto has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_num_v = tail_num_.value; auto invoke_splitk_path = [&](auto split_k_) { - return mx_flatmm_calc +template auto preShuffleWeight(ck_tile::HostTensor& src) { auto src_lengths = src.get_lengths(); @@ -181,8 +183,8 @@ auto preShuffleWeight(ck_tile::HostTensor& src) constexpr int packed_size = ck_tile::numeric_traits::PackedSize; int KPack = std::is_same_v ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16 - int NLane = N_Warp_Tile; - int KLane = 64 / NLane; + + int KLane = ck_tile::get_warp_size() / NLane; int K0 = K / (KLane * KPack); ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1})); @@ -211,68 +213,10 @@ auto preShuffleWeight(ck_tile::HostTensor& src) return shuffled; } -template -auto preShuffleScale(ck_tile::HostTensor& src) -{ - auto src_lengths = src.get_lengths(); - const auto MN = KLast ? src_lengths[0] : src_lengths[1]; - const auto K = KLast ? src_lengths[1] : src_lengths[0]; - - size_t MNXdlPack = 2; - size_t KXdlPack = 2; - size_t XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16 - size_t XdlKThread = 64 / XdlMNThread; - - const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack); - - ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1})); - - size_t K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(size_t n = 0; n < MN_Paded; ++n) - { - for(size_t k = 0; k < K; ++k) - { - auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - auto tempn = n % (XdlMNThread * MNXdlPack); - auto n1 = tempn % XdlMNThread; // i XdlMNThread - auto n2 = tempn / XdlMNThread; // i MNXdlPack - - auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat - auto tempk = k % (XdlKThread * KXdlPack); - auto k1 = tempk % XdlKThread; // i XdlKThread - auto k2 = tempk / XdlKThread; // i KXdlPack - - auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - - if constexpr(KLast) - shuffled(outputIndex) = n < MN ? src(n, k) : dtype{}; - else - shuffled(outputIndex) = n < MN ? src(k, n) : dtype{}; - } - } - return shuffled; -} - #include "run_mx_flatmm.inc" -int run_mx_flatmm_example(int argc, char* argv[]) +int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -281,6 +225,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) std::string b_layout = arg_parser.get_str("b_layout"); int persistent_opt = arg_parser.get_int("persistent"); + std::cout << "Using default warptile of 16x16x128." << std::endl; + if(a_layout == "R" && b_layout == "C") { if(mx_prec == "fp4" || mx_prec == "fp4xfp4") @@ -289,8 +235,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + MXFlatmm_GFX950_FP4FP4_Traits, + false>(arg_parser, Row{}, Col{}, Row{}); else throw std::runtime_error("Only non-persistent kernels are supported currently!"); } @@ -300,8 +246,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + MXFlatmm_GFX950_FP6FP6_Traits, + false>(arg_parser, Row{}, Col{}, Row{}); else throw std::runtime_error("Only support non-persistent kernel now!"); } @@ -311,8 +257,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + MXFlatmm_GFX950_FP8FP8_Traits, + false>(arg_parser, Row{}, Col{}, Row{}); else throw std::runtime_error("Only support non-persistent kernel now!"); } @@ -322,8 +268,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + MXFlatmm_GFX950_FP8FP4_Traits, + false>(arg_parser, Row{}, Col{}, Row{}); else throw std::runtime_error("Only support non-persistent kernel now!"); } @@ -333,8 +279,8 @@ int run_mx_flatmm_example(int argc, char* argv[]) return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + MXFlatmm_GFX950_FP4FP8_Traits, + false>(arg_parser, Row{}, Col{}, Row{}); else throw std::runtime_error("Only support non-persistent kernel now!"); } @@ -359,7 +305,7 @@ int main(int argc, char* argv[]) int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) { - return run_mx_flatmm_example(argc, argv); + return run_mx_flatmm_example(arg_parser); } else if(warp_tile == 1) { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index d4922bb44c..f3a9787b8e 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -11,167 +11,9 @@ #include "ck_tile/ops/flatmm.hpp" #include "ck_tile/ops/gemm.hpp" -// GEMM config with 16x16 warp tile -struct MXfp4_FlatmmConfig16 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 512; - static constexpr ck_tile::index_t K_Tile = 256; +#include "mx_flatmm_arch_traits.hpp" - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -struct MXfp6_FlatmmConfig16 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -struct MXfp8_FlatmmConfig16 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -struct MXf8f4_FlatmmConfig16 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; -struct MXf4f8_FlatmmConfig16 -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int kBlockPerCu = 1; - static constexpr int TileParitionerGroupNum = 8; - static constexpr int TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; - - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = false; -}; - -template +struct MXFlatmmArchTraits +{ + static constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern + + using Config = FlatmmConfig; + + template + using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1; + + static constexpr int GetNLane() { return Config::N_Warp_Tile; } + + template + static auto preShuffleScale(ck_tile::HostTensor& src) + { + auto src_lengths = src.get_lengths(); + const auto MN = KLast ? src_lengths[0] : src_lengths[1]; + const auto K = KLast ? src_lengths[1] : src_lengths[0]; + + size_t MNXdlPack = 2; + size_t KXdlPack = 2; + size_t XdlMNThread = Config::N_Warp_Tile; // 16 + size_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread; + + const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack); + + ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1})); + + size_t K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(size_t n = 0; n < MN_Paded; ++n) + { + for(size_t k = 0; k < K; ++k) + { + auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + auto tempn = n % (XdlMNThread * MNXdlPack); + auto n1 = tempn % XdlMNThread; // i XdlMNThread + auto n2 = tempn / XdlMNThread; // i MNXdlPack + + auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat + auto tempk = k % (XdlKThread * KXdlPack); + auto k1 = tempk % XdlKThread; // i XdlKThread + auto k2 = tempk / XdlKThread; // i KXdlPack + + auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + + n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2; + + if constexpr(KLast) + shuffled(outputIndex) = n < MN ? src(n, k) : dtype{}; + else + shuffled(outputIndex) = n < MN ? src(k, n) : dtype{}; + } + } + return shuffled; + } +}; + +using MXFlatmm_GFX950_FP4FP4_Traits = + MXFlatmmArchTraits; +using MXFlatmm_GFX950_FP8FP8_Traits = + MXFlatmmArchTraits; +using MXFlatmm_GFX950_FP6FP6_Traits = + MXFlatmmArchTraits; +using MXFlatmm_GFX950_FP8FP4_Traits = + MXFlatmmArchTraits; +using MXFlatmm_GFX950_FP4FP8_Traits = + MXFlatmmArchTraits; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake index 9250dbe7ae..101719361c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -6,31 +6,33 @@ function(mx_flatmm_instance_generate FILE_LIST) set(A_LAYOUT ROW) set(B_LAYOUT COL) set(C_LAYOUT ROW) - set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16") - set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16") - set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16") - set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16") - set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16") + + set(MXFLATMM_ARCH) + + if (GPU_TARGETS MATCHES "gfx95") + list(APPEND MXFLATMM_ARCH MXFlatmm_GFX950_) + endif() # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. foreach(PERSISTENT false) foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8) - set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}}) string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE}) list(GET DATA_TYPE_AB 0 A_DATA_TYPE) list(GET DATA_TYPE_AB 1 B_DATA_TYPE) - - foreach(SPLIT_K false true) - foreach(HAS_HOT_LOOP false true) - foreach(TAIL_NUMBER ODD EVEN) - set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp) - string(TOLOWER ${KERNEL_FILE} KERNEL_FILE) - configure_file( - ${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in - ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE} - @ONLY) - list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}) + foreach(ARCH ${MXFLATMM_ARCH}) + set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits") + foreach(SPLIT_K false true) + foreach(HAS_HOT_LOOP false true) + foreach(TAIL_NUMBER ODD EVEN) + set(KERNEL_FILE mxgemm/instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp) + string(TOLOWER ${KERNEL_FILE} KERNEL_FILE) + configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in + ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE} + @ONLY) + list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}) + endforeach() endforeach() endforeach() endforeach() diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index e6d612f0d6..a5bbd5ed3c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -4,7 +4,7 @@ #include "mx_flatmm_instance.hpp" // clang-format off -#define FLATMM_CONFIG @FLATMM_CONFIG@ +#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@ #define A_DATA_TYPE @A_DATA_TYPE@ #define B_DATA_TYPE @B_DATA_TYPE@ #define C_DATA_TYPE @C_DATA_TYPE@ @@ -37,7 +37,7 @@ inline constexpr int ScaleGranularityK = 32; using ScaleM = ck_tile::FlatmmScalePointer; using ScaleN = ck_tile::FlatmmScalePointer; -template float mx_flatmm_calc, diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index 01128f8fe8..90bd24d5dc 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -16,7 +16,7 @@ template using is_row_major_t = ck_tile::bool_constant< std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>; -template & args, const ck_tile::stream_config& s) { + using FlatmmConfig = typename MXFlatmmArchTraits::Config; + using FlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -63,7 +65,8 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, constexpr auto scheduler = FlatmmConfig::Scheduler; ck_tile::ignore = Splitk; - constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern + // determined by scale shuffle pattern + constexpr int BlockedXDLN_PerWarp = MXFlatmmArchTraits::BlockedXDLN_PerWarp; using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem& args, HasHotLoop, TailNum>; - using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1; + using MXFlatmmPipeline = + typename MXFlatmmArchTraits::template MXFlatmmPipeline; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner -int run_mx_flatmm_with_layouts(int argc, - char* argv[], +int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - using ADataType = PrecActType; using BDataType = PrecWeightType; using AccDataType = float; @@ -111,9 +106,9 @@ int run_mx_flatmm_with_layouts(int argc, } } - const auto b_shuffled_host = preShuffleWeight(b_origin_host); - const auto scale_a_shuffled = preShuffleScale(scale_a); - const auto scale_b_shuffled = preShuffleScale(scale_b); + const auto b_shuffled_host = preShuffleWeight(b_origin_host); + const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_a); + const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_b); ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes()); @@ -135,7 +130,7 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::FlatmmScalePointer{ static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; - invoke_mx_flatmm,