From 3dfa794fab62dca7c0499791d37298a49630d5ee Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 16 Dec 2025 08:22:52 -0800 Subject: [PATCH 1/3] Add build trace diagnostics to CI. (#3432) * generate and visualize build traces for all archs * generate build traces in all cases * fix jenkins logic * fix typo * use more threads for parsing dependency map * add script to parse ninja traces and issue warnings * fix python script syntax and header * fix python syntax one more time * fix python syntax --- Jenkinsfile | 34 ++++++++------- .../src/enhanced_ninja_parser.py | 2 +- script/parse_ninja_trace.py | 43 +++++++++++++++++++ 3 files changed, 62 insertions(+), 17 deletions(-) create mode 100755 script/parse_ninja_trace.py diff --git a/Jenkinsfile b/Jenkinsfile index aea14c78b6..2a1d1fd904 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -72,10 +72,8 @@ def sendFailureNotifications() { } } -def generateAndArchiveBuildTraceVisualization() { +def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) { try { - def buildTraceFileName = "ck_build_trace.json"; - // Attempt to download the build trace file to check if it exists def traceFileExists = false try { @@ -628,15 +626,17 @@ def cmake_build(Map conf=[:]){ sh cmd //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json" + archiveArtifacts "ck_build_trace_${check_arch_name()}.json" + sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json" + if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){ if (params.NINJA_FTIME_TRACE) { - echo "running ninja ftime trace" + echo "running ClangBuildAnalyzer" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" - archiveArtifacts "clang_build_analysis.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log" + archiveArtifacts "clang_build_analysis_${check_arch_name()}.log" } - sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace.json" - archiveArtifacts "ck_build_trace.json" + // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ @@ -652,9 +652,8 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - def arch_name = check_arch_name() - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" } } if(params.BUILD_INSTANCES_ONLY){ @@ -680,9 +679,8 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - def arch_name = check_arch_name() - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" } } } @@ -1887,7 +1885,11 @@ pipeline { node(rocmnode("nogpu")) { script { // Simulate capture - generateAndArchiveBuildTraceVisualization() + generateAndArchiveBuildTraceVisualization("ck_build_trace_gfx11.json") + generateAndArchiveBuildTraceVisualization("ck_build_trace_gfx12.json") + generateAndArchiveBuildTraceVisualization("ck_build_trace_gfx90a.json") + generateAndArchiveBuildTraceVisualization("ck_build_trace_gfx942.json") + generateAndArchiveBuildTraceVisualization("ck_build_trace_gfx950.json") } cleanWs() } diff --git a/script/dependency-parser/src/enhanced_ninja_parser.py b/script/dependency-parser/src/enhanced_ninja_parser.py index 2ac8e8537a..ebcd878915 100644 --- a/script/dependency-parser/src/enhanced_ninja_parser.py +++ b/script/dependency-parser/src/enhanced_ninja_parser.py @@ -99,7 +99,7 @@ class EnhancedNinjaDependencyParser: print("No object files found - skipping dependency extraction") return - max_workers = min(16, len(object_files)) # Limit concurrent processes + max_workers = min(128, len(object_files)) # Limit concurrent processes with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all object files for processing diff --git a/script/parse_ninja_trace.py b/script/parse_ninja_trace.py new file mode 100755 index 0000000000..1706214f49 --- /dev/null +++ b/script/parse_ninja_trace.py @@ -0,0 +1,43 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +import json +import os +import sys + + +def read_json_file(file_path): + if not os.path.isfile(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as file: + try: + data = json.load(file) + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON format: {e}", e.doc, e.pos) + return data + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python parse_json.py ") + sys.exit(1) + + json_file_path = sys.argv[1] + + try: + parsed_data = read_json_file(json_file_path) + print("JSON parsed successfully!") + threshold = 15 # max number of minutes for compilation + for i in range(len(parsed_data)): + if parsed_data[i]["dur"] > threshold * 60000000: + print( + f"build duration of {parsed_data[i]['name']} exceeds {threshold} minutes! actual build time: {parsed_data[i]['dur'] / 60000000:.2f} minutes!" + ) + + except FileNotFoundError as fnf_err: + print(f"Error: {fnf_err}") + except json.JSONDecodeError as json_err: + print(f"Error: {json_err}") + except Exception as e: + print(f"Unexpected error: {e}") From 57e1e4a8485835004c36144ba1b39fc3051538a7 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Wed, 17 Dec 2025 10:01:48 +0800 Subject: [PATCH 2/3] [CK_TILE] Add FP8xF4 Flatmm (#3401) * Refactor policy * fix a bank conflict * Enable mixed mx flatmm * Update --- CHANGELOG.md | 2 + .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 24 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 63 ++++ .../18_flatmm/mxgemm/mx_flatmm_instance.cmake | 14 +- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 7 +- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 13 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 317 +++++++----------- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 7 +- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 7 +- 9 files changed, 231 insertions(+), 223 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 997fb8bb8c..a69ce2260e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. +* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline ### Changed @@ -36,6 +37,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE * Added the blockscale 2D support for CK_TILE GEMM. +* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types ### Changed diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 0134465347..d6c84f3064 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "32", "m dimension") - .insert("n", "128", "n dimension") + .insert("n", "512", "n dimension") .insert("k", "256", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") @@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[]) else throw std::runtime_error("Only support non-persistent kernel now!"); } + else if(mx_prec == "fp8xfp4") + { + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); + } + else if(mx_prec == "fp4xfp8") + { + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); + } else { throw std::runtime_error("Unsupported data_type!"); diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index e374a4ddd3..0b6185590f 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -76,6 +76,69 @@ struct MXfp8_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXf8f4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; +struct MXf4f8_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + template struct MXFlatmmPipelineProblem : FlatmmPipelineProblem= DsReadPreload) ? DsReadPreload @@ -470,11 +470,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(); - } - template CK_TILE_DEVICE auto operator()(Args&&... args) const { @@ -684,7 +679,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 a_warp_tensor; // preload A00,A10... from lds - s_waitcnt_barrier(); + s_waitcnt_barrier(); static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MXdlPack; constexpr auto kIter = loadIter / MXdlPack; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 4d76ab7da2..e188ddec61 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -7,6 +7,8 @@ namespace ck_tile { +namespace detail { +template struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static constexpr auto I0 = number<0>{}; @@ -14,27 +16,47 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr auto I2 = number<2>{}; static constexpr index_t kDramLoadPackBytes = 128; + static constexpr index_t DWORDx4 = 16; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; static constexpr int KXdlPack = 2; - template - static inline constexpr auto wg_attr_num_access = - std::is_same_v, pk_fp4_t> - ? WGAttrNumAccessEnum::Single - : WGAttrNumAccessEnum::Double; + private: + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + + using TileShape = typename Problem::BlockGemmShape; + using BlockWarps = typename TileShape::BlockWarps; + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t WaveNum = BlockSize / WaveSize; + + static constexpr index_t MPerBlock = TileShape::kM; + static constexpr index_t NPerBlock = TileShape::kN; + static constexpr index_t KPerBlock = TileShape::kK; + static constexpr index_t MWarps = BlockWarps::at(I0); + static constexpr index_t NWarps = BlockWarps::at(I1); + static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size"); + + static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0); + static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1); + static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2); + static_assert(MPerXdl == 16 && NPerXdl == 16); + static constexpr index_t K_Lane = get_warp_size() / 16; // 4 + static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 + + public: + static constexpr index_t AK1 = DWORDx4 * APackedSize; + static constexpr index_t BK1 = DWORDx4 * BPackedSize; - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - static_assert( - sizeof(ADataType) * numeric_traits::PackedSize == - sizeof(BDataType) * numeric_traits::PackedSize, - "sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!"); - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher< // ADataType, @@ -43,10 +65,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - Problem::TransposeC, - false, - false, - wg_attr_num_access>; + Problem::TransposeC>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< // ADataType, BDataType, @@ -56,28 +75,20 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy return BlockFlatmmASmemBSmemCRegV1{}; } - template + template CK_TILE_DEVICE static constexpr auto MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view) { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - const auto& naive_desc = naive_view.get_tensor_descriptor(); constexpr auto ndims = remove_cvref_t::get_num_of_dimension(); static_assert(ndims == 2, "only support 2D tensor"); const auto rows = naive_desc.get_length(number<0>{}); const auto cols = naive_desc.get_length(number<1>{}); - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - const index_t K0 = cols / (K1 * K2); - const auto col_lens = make_tuple(K0, number{}, number{}); + constexpr index_t K2 = AK1; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + const index_t K0 = cols / (K1 * K2); + const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds const index_t M0 = rows / M1; @@ -106,25 +117,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy TensorView::DstInMemOp>{naive_view.buf_, desc}; } - template CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { - - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K2 = AK1; // f4=32; f8=16 constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 - constexpr index_t M2 = get_warp_size() / K1; // 8 - constexpr index_t M1 = BlockSize / get_warp_size(); // 4 + constexpr index_t M2 = WaveSize / K1; // 8 + constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); @@ -139,28 +139,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 0, 2>>{}); } - template CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor() { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - static_assert(MPerXdl == 16 && NPerXdl == 16); - static_assert(std::is_same_v); - - /*reduce transform layers,compare with old ck*/ - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t APackedSize = numeric_traits::PackedSize; - constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 - constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + constexpr index_t K2 = AK1; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); - constexpr index_t M3 = 4; // so that we can use imm offset to load lds - constexpr index_t M2 = get_warp_size() / K1 / M3; // 2 - constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 + constexpr index_t M3 = 4; // so that we can use imm offset to load lds + constexpr index_t M2 = WaveSize / K1 / M3; // 2 + constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16 static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!"); @@ -168,14 +156,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // make_tuple(number{}, - number{}, number{}, + number{}, number{}, number{}, number{}, number{}), - make_tuple(number{}, - number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, @@ -187,8 +175,8 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(M0), - make_pass_through_transform(M1), make_pass_through_transform(K0), + make_pass_through_transform(M1), make_pass_through_transform(M2), make_xor_transform(make_tuple(number{}, number{})), make_pass_through_transform(number{})), @@ -210,103 +198,71 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy make_tuple(number{}, number{}, number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}), + make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}), make_tuple(sequence<0>{}, sequence<1>{})); // return a_lds_block_desc_permuted; return a_lds_block_desc; } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution() { - using TileShape = typename Problem::BlockGemmShape; + static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); - static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - - constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); - constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); - constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 - - constexpr int K_Lane = 64 / M_Lane; // 4 - - constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32 - constexpr index_t num_access_v = static_cast(wg_attr_num_access); - constexpr int K1 = K_Thread / num_access_v; // 16 - - return make_static_tile_distribution( - std::conditional_t< - num_access_v == 1, - tile_distribution_encoding< - sequence, - tuple, sequence>, + if constexpr(K_Thread == AK1) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, sequence<2>, - sequence<1>>, - tile_distribution_encoding< // - sequence, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 2>, - sequence<0, 2>>>{}); + sequence<1>>{}); + else + return make_static_tile_distribution(tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 2>, + sequence<0, 2>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - using BDataType = remove_cvref_t; - constexpr index_t BPack = numeric_traits::PackedSize; - - static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t K1 = WaveSize; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t K0 = KWavePerBlk; - constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - constexpr index_t kKPerThread = 32; - constexpr index_t num_access_v = static_cast(wg_attr_num_access); - constexpr index_t K2 = kKPerThread / num_access_v; - - return make_static_tile_distribution( - std::conditional_t< // - num_access_v == 1, + if constexpr(BK1 == K_Thread) + return make_static_tile_distribution( tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 1 64 32 + tuple, // 4 2 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, - sequence<2>>, + sequence<2>>{}); + else + return make_static_tile_distribution( tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, - sequence<0, 3>>>{}); + sequence<0, 3>>{}); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) { - - using BDataType = remove_cvref_t; - constexpr auto BPackedSize = numeric_traits::PackedSize; - constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; @@ -314,7 +270,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static_assert(std::decay_t::get_num_of_dimension() == 2); auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( make_naive_tensor_descriptor_packed(make_tuple( flat_n, flat_k / flat_k_per_block, number{})), @@ -331,39 +287,25 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy byte_tensor_view, make_tuple(number{}, number{}), {origin_tmp[0], origin_tmp[1] / BPackedSize}, - MakeMX_BFlatBytesDramTileDistribution()); + MakeMX_BFlatBytesDramTileDistribution()); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); constexpr index_t K_Lanes = 64 / M_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = M_Lanes; - constexpr index_t Y1 = M_Warps; - constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); + constexpr index_t Y1 = MWarps; + constexpr index_t Y0 = MPerBlock / (MXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // repeat N_warps + tile_distribution_encoding, // repeat NWarps tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -371,36 +313,22 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t WaveSize = get_warp_size(); - constexpr index_t WaveNum = BlockSize / WaveSize; - - constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1); - - constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); - constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); - - static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); - constexpr index_t N_Lanes = TileShape::WarpTile::at(I1); constexpr index_t K_Lanes = 64 / N_Lanes; // Y dimension (M) decomposition constexpr index_t Y2 = N_Lanes; - constexpr index_t Y1 = N_Warps; - constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2); + constexpr index_t Y1 = NWarps; + constexpr index_t Y0 = NPerBlock / (NXdlPack * Y1 * Y2); // X dimension (K) decomposition constexpr index_t X0 = K_Lanes; constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load return make_static_tile_distribution( - tile_distribution_encoding, // ? + tile_distribution_encoding, // ? tuple, sequence>, tuple, sequence<2, 1>>, tuple, sequence<0, 2>>, @@ -408,20 +336,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<0, 1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0); - constexpr index_t M_Lane = TileShape::WarpTile::at(I0); - constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t MWavePerBlk = M_Warp; - return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -430,20 +349,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; - - constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{}); - constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1); - constexpr index_t N_Lane = TileShape::WarpTile::at(I1); - constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{}); - constexpr index_t NWavePerBlk = N_Warp; - return make_static_tile_distribution( - tile_distribution_encoding, // ? - tuple, // second direction + tile_distribution_encoding, // ? + tuple, // second direction sequence>, // first direction tuple, sequence<2, 1>>, // which direction tuple, sequence<0, 1>>, // which index @@ -452,20 +362,41 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); } - template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - using ADataType = remove_cvref_t; - constexpr index_t APackedSize = numeric_traits::PackedSize; - return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / + return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / APackedSize; } - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return GetSmemSizeA(); + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } +}; +} // namespace detail + +struct MXFlatmmPipelineAgBgCrPolicy +{ + +#define FORWARD_METHOD_(method) \ + template \ + CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \ + { \ + return detail::MXFlatmmPipelineAgBgCrPolicy::method(std::forward(args)...); \ } + + FORWARD_METHOD_(GetBlockFlatmm); + FORWARD_METHOD_(MakeMX_AAsyncLoadDramDescriptor); + FORWARD_METHOD_(MakeMX_ADramTileDistribution); + FORWARD_METHOD_(MakeMX_ALdsBlockDescriptor); + FORWARD_METHOD_(MakeMX_ALDS_TileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramTileDistribution); + FORWARD_METHOD_(MakeMX_BFlatBytesDramWindow); + FORWARD_METHOD_(MakeMX_ScaleA_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_DramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleA_FlatDramTileDistribution); + FORWARD_METHOD_(MakeMX_ScaleB_FlatDramTileDistribution); + FORWARD_METHOD_(GetSmemSizeA); + FORWARD_METHOD_(GetSmemSize); + +#undef FORWARD_METHOD_ }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index a2c320f3e6..44a09423ee 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -306,10 +306,9 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -template -using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl< - WarpGemmAttributeMfma, - AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< + WarpGemmAttributeMfma, AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< // WarpGemmAttributeMfma, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9d928a7cfa..82c6e43834 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -116,15 +116,12 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; -template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template +struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed; }; template struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed; }; -template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; From 292df2719f28cd01464d5d059820684790c101da Mon Sep 17 00:00:00 2001 From: KateJu <153474223+kateju12@users.noreply.github.com> Date: Wed, 17 Dec 2025 11:50:49 +0800 Subject: [PATCH 3/3] fix some minor error (#3409) ReduceWithNoIndexTesBtHalfFloat_AMAX: fix typo error to ReduceWithNoIndexTesBHalfFloat_AMAX reduce_blockwise_test( + pass = reduce_blockwise_test( arg.do_verification, arg.init_method, arg.time_kernel, diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 655593228a..e869ca1a76 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -193,7 +193,7 @@ TYPED_TEST(ReduceWithNoIndexHalf, ReduceWithNoIndexTestHalf_MAX) this->template Run(); } -TYPED_TEST(ReduceWithNoIndexBHalfFloat, ReduceWithNoIndexTesBtHalfFloat_AMAX) +TYPED_TEST(ReduceWithNoIndexBHalfFloat, ReduceWithNoIndexTesBHalfFloat_AMAX) { // trigger Run() -> Generic this->template Run(); diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index f46bfc4f30..c6cbea8610 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -193,7 +193,7 @@ TYPED_TEST(ReduceWithIndexHalf, ReduceWithIndexTestHalf_MAX) this->template Run(); } -TYPED_TEST(ReduceWithIndexBHalfFloat, ReduceWithIndexTesBtHalfFloat_AMAX) +TYPED_TEST(ReduceWithIndexBHalfFloat, ReduceWithIndexTesBHalfFloat_AMAX) { // trigger Run() -> Generic this->template Run();