diff --git a/Jenkinsfile b/Jenkinsfile index ca7c4f1d93..80721ea6d3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -581,7 +581,7 @@ def cmake_build(Map conf=[:]){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" } - if ((params.RUN_BUILDER_TESTS || params.RUN_FULL_CONV_TILE_TESTS) && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args } setup_cmd = conf.get( @@ -1428,8 +1428,8 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ python3 ../experimental/builder/src/generate_instances.py --mode=profiler && \ - ../script/cmake-ck-dev.sh ../ gfx90a && \ + execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \ + cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \ make -j64 test_grouped_convnd_fwd_tile && \ ./bin/test_grouped_convnd_fwd_tile""" } diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 993330f989..51e0359ab6 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,6 +68,8 @@ set(GTEST_CXX_FLAGS -Wno-deprecated -Wno-unsafe-buffer-usage -Wno-float-equal + -Wno-lifetime-safety-intra-tu-suggestions + -Wno-lifetime-safety-cross-tu-suggestions ) if(WIN32) diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp index 33f398cc2a..b526204384 100644 --- a/example/ck_tile/01_fmha/bias.hpp +++ b/example/ck_tile/01_fmha/bias.hpp @@ -106,7 +106,7 @@ struct bias_info return info; } - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi) { bi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index f85b811116..c780bf7b6b 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -191,7 +191,7 @@ struct mask_info return area; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi) { mi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24..da588910b2 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { @@ -58,3 +61,4 @@ struct quant_scale_info return os; } }; +#pragma clang diagnostic pop diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 7451ee25b0..d77e3c9322 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -21,7 +21,6 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") - add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index d6c84f3064..1141717545 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor& src) const int K = src_lengths[0]; const int N = src_lengths[1]; constexpr int packed_size = ck_tile::numeric_traits::PackedSize; - int KPack = 16 * packed_size; // fp4:32 or fp8:16 - int NLane = N_Warp_Tile; - int KLane = 64 / NLane; - int K0 = K / (KLane * KPack); + int KPack = + std::is_same_v ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16 + int NLane = N_Warp_Tile; + int KLane = 64 / NLane; + int K0 = K / (KLane * KPack); ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1})); @@ -295,7 +296,14 @@ int run_mx_flatmm_example(int argc, char* argv[]) } else if(mx_prec == "fp6" || mx_prec == "fp6xfp6") { - throw std::runtime_error("fp6xfp6 is not supported."); + if(persistent_opt == 0) + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + else + throw std::runtime_error("Only support non-persistent kernel now!"); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index 0b6185590f..d4922bb44c 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -44,6 +44,38 @@ struct MXfp4_FlatmmConfig16 static constexpr bool TiledMMAPermuteN = false; }; +struct MXfp6_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + struct MXfp8_FlatmmConfig16 { static constexpr ck_tile::index_t M_Tile = 128; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake index 5e86cd7133..9250dbe7ae 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -8,13 +8,14 @@ function(mx_flatmm_instance_generate FILE_LIST) set(C_LAYOUT ROW) set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16") + set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16") set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16") set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16") # foreach(PERSISTENT false true) # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. foreach(PERSISTENT false) - foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8) + foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8) set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}}) string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE}) list(GET DATA_TYPE_AB 0 A_DATA_TYPE) diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in index 9675d3345b..e6d612f0d6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -19,6 +19,7 @@ using FP4 = ck_tile::pk_fp4_t; using FP8 = ck_tile::fp8_t; +using FP6 = ck_tile::pk_fp6x16_t; using FP16 = ck_tile::fp16_t; using BF16 = ck_tile::bf16_t; diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index b4d1fe237b..54c23e2266 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -68,24 +68,47 @@ int run_mx_flatmm_with_layouts(int argc, M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); + if constexpr(std::is_same_v) + { + auto a_buffer_bytes = a_host.get_element_space_size_in_bytes(); + auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes(); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b); + std::vector random_bufA(a_buffer_bytes); + std::vector random_bufB(b_buffer_bytes); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(1, 4); - if(init_method == 0) - { - ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); - } - else if(init_method == 1) - { - ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + for(size_t i = 0; i < a_buffer_bytes; ++i) + random_bufA[i] = static_cast(dis(gen)); + + for(size_t i = 0; i < b_buffer_bytes; ++i) + random_bufB[i] = static_cast(dis(gen)); + + memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes); + memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes); } else { - throw std::runtime_error("wrong! Unexpected init_method"); + if(init_method == 0) + { + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! Unexpected init_method"); + } } const auto b_shuffled_host = preShuffleWeight(b_origin_host); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index a95c0346cf..1520f2c591 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index d2b95d3263..a93fe15a1b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index a8c13c1b3d..39747ff0bc 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 6576b22c03..ed18cd8890 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 540d5725dd..508f3ac8ec 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - + // Split-K validation is handled by Kernel::IsSupportedArgument + // Split-K is only supported for BQuantGrouped without preshuffle if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else - { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - } - } - else if(init_method == 4) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } - else if(init_method == 5) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); - } - // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) - for(ck_tile::index_t row = 0; - row < static_cast(aq_tensor_ptr->get_length(0)); - ++row) - { - for(ck_tile::index_t col = 0; - col < static_cast(aq_tensor_ptr->get_length(1)); - ++col) - { - (*aq_tensor_ptr)(row, col) = static_cast(col + 1); - } - } - // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } else { a_m_k.SetZero(); diff --git a/include/ck/BUILD_TIME_OPTIMIZATION.md b/include/ck/BUILD_TIME_OPTIMIZATION.md index 94b292b878..045d4bb929 100644 --- a/include/ck/BUILD_TIME_OPTIMIZATION.md +++ b/include/ck/BUILD_TIME_OPTIMIZATION.md @@ -105,7 +105,7 @@ struct generate_identity_sequence generate_tuple(generate_identity_sequence{}, Number{}); ``` -This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction). +This significantly reduces template instantiations for `transform_tensor_descriptor`. **Example: container_concat** @@ -135,7 +135,7 @@ __host__ __device__ constexpr auto container_concat(const Tuple& tx, const } ``` -This reduced `container_concat` instantiations from 186 to 93 (50% reduction). +This reduces `container_concat` template instantiations. **Example: make_uniform_tuple** @@ -192,7 +192,7 @@ __host__ __device__ constexpr index_t find_source_index(Sequence) } ``` -This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%. +This significantly reduces `sequence_map_inverse` instantiations and compile time. ### 4. Use Fold Expressions for Accumulation @@ -222,4 +222,4 @@ __host__ __device__ constexpr auto compute_element_space_size( } ``` -This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%. +This reduces `calculate_element_space_size` instantiations and compile time. diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index db45199b17..22d744ff15 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -13,7 +13,7 @@ namespace ck { template -std::ostream& operator<<(std::ostream& os, const std::vector& v) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const std::vector& v) { std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); return os; @@ -27,7 +27,8 @@ std::ostream& operator<<(std::ostream& os, const std::array& v) } template -std::ostream& operator<<(std::ostream& os, const TensorDescriptor& desc) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const TensorDescriptor& desc) { constexpr index_t nDim = remove_cvref_t::GetNumOfDimension(); diff --git a/include/ck/library/utility/convolution_parameter.hpp b/include/ck/library/utility/convolution_parameter.hpp index 354b112040..a25002409b 100644 --- a/include/ck/library/utility/convolution_parameter.hpp +++ b/include/ck/library/utility/convolution_parameter.hpp @@ -110,4 +110,5 @@ ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) } // namespace utils } // namespace ck -std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::utils::conv::ConvParam& p); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 1dda0a4863..2e95ee8cf3 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -23,10 +23,14 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" +#pragma clang diagnostic ignored "-Wlifetime-safety-cross-tu-suggestions" + namespace ck { template -std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +std::ostream& LogRange([[clang::lifetimebound]] std::ostream& os, Range&& range, std::string delim) { bool first = true; for(auto&& v : range) @@ -580,8 +584,9 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); - friend std::ostream& operator<<(std::ostream& os, ChosenLayout tag); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const HostTensorDescriptor& desc); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, ChosenLayout tag); private: std::vector mLens; @@ -1171,3 +1176,4 @@ struct Tensor }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index 529745e3b9..c3f3bd0c91 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -4,6 +4,8 @@ #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // StaticTensor for Scalar @@ -270,4 +272,5 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_elem } } // namespace ck +#pragma clang diagnostic pop #endif diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 19a4748732..5a6c335b2c 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -6,6 +6,9 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/multi_index.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template @@ -29,7 +32,10 @@ struct PassThrough __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -305,7 +311,10 @@ struct RightPad __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -403,7 +412,10 @@ struct Embed __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1074,7 +1086,10 @@ struct Merge_v2_magic_division __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1366,7 +1381,10 @@ struct Merge_v3_division_mod __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1480,7 +1498,10 @@ struct UnMerge __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1640,7 +1661,10 @@ struct ConvBwdDataImplicitGemmOutTransform __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const @@ -2236,3 +2260,4 @@ struct Xor } }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 79c5881d48..ee8c7ed71b 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -23,7 +23,10 @@ struct TensorAdaptor { __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() { diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d11..a237c4219d 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -7,6 +7,8 @@ #include "ck/utility/sequence_helper.hpp" #include "ck/tensor_description/multi_index_transform.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -179,7 +181,10 @@ struct TensorDescriptor } // TODO make these private - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionIdss() { @@ -253,9 +258,12 @@ struct TensorCoordinate __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } // TODO make these private - __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + __host__ __device__ constexpr const auto& GetHiddenIndex() const [[clang::lifetimebound]] + { + return idx_hidden_; + } - __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + __host__ __device__ auto& GetHiddenIndex() [[clang::lifetimebound]] { return idx_hidden_; } __host__ __device__ constexpr auto GetVisibleIndex() const { @@ -284,7 +292,7 @@ struct TensorCoordinateStep __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } // TODO make these private - __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const [[clang::lifetimebound]] { return idx_diff_visible_; } @@ -613,3 +621,4 @@ using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index 260ebcf4cc..35d987a79a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -63,7 +63,10 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 true> c_thread_buf_; - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + __host__ __device__ constexpr auto& GetCThreadBuffer() [[clang::lifetimebound]] + { + return c_thread_buf_; + } __device__ static auto GetWaveIdx() { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f831c0f6cf..e41cf8c82d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -10,6 +10,8 @@ #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -1031,3 +1033,4 @@ struct BlockwiseGemmXdlops_v2 }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 1dba7f67a1..65a326e3e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -8,6 +8,9 @@ #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template ::value, bool>::type = false> -std::ostream& operator<<(std::ostream& os, const Layout&) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const Layout&) { os << Layout::name; return os; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 8c316bc71d..6060889c10 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -17,6 +17,9 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to @@ -1132,3 +1135,4 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight }; // namespace ck } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 44259f0601..4b64b76cc7 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -44,7 +44,8 @@ struct get_carrier<3> // replacement of host std::copy_n() template - __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + __device__ static OutputIterator + copy_n(InputIterator from, Size size, [[clang::lifetimebound]] OutputIterator to) { if(0 < size) { diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ebdbbb107d..204b199629 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck/utility/data_type.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // vector_type @@ -116,7 +118,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -136,7 +138,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -248,7 +250,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -272,7 +274,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -583,7 +585,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, @@ -754,7 +756,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -1427,7 +1429,7 @@ struct non_native_vector_base< } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); @@ -1627,7 +1629,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, @@ -1797,7 +1799,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -2284,3 +2286,4 @@ using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 0cb0b4caf8..4cabd89e33 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { namespace internal { template @@ -188,5 +191,5 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#pragma clang diagnostic pop #endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp index 4421386f59..a224011a04 100644 --- a/include/ck/utility/pipeline_enum.hpp +++ b/include/ck/utility/pipeline_enum.hpp @@ -25,7 +25,8 @@ enum struct PipelineVersion } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::PipelineVersion& p) { switch(p) { diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp index 0c4bfabaf3..67c5c3b50a 100644 --- a/include/ck/utility/scheduler_enum.hpp +++ b/include/ck/utility/scheduler_enum.hpp @@ -70,7 +70,8 @@ enum struct TailNumber } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::LoopScheduler& s) { switch(s) { diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index d49817eb8f..7e47da5bf8 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -5,6 +5,8 @@ #include "statically_indexed_array.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // static buffer for scalar @@ -104,7 +106,7 @@ struct StaticBufferTupleOfVector // Set S // i is offset of S template - __host__ __device__ constexpr S& operator()(Number i) + __host__ __device__ constexpr S& operator()(Number i) [[clang::lifetimebound]] { constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; @@ -195,3 +197,4 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber) } } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 1657595030..16cd35e1d6 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -51,7 +51,7 @@ get_tuple_element_data_reference(const TupleElementKeyData& x) // for write access of tuple element template __host__ __device__ constexpr Data& -get_tuple_element_data_reference(TupleElementKeyData& x) +get_tuple_element_data_reference([[clang::lifetimebound]] TupleElementKeyData& x) { return x.mData; } @@ -106,6 +106,7 @@ struct TupleImpl, Xs...> : TupleElementKeyData __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) + [[clang::lifetimebound]] { return get_tuple_element_data_reference>(*this); } @@ -147,7 +148,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& At(Number) + __host__ __device__ constexpr auto& At(Number) [[clang::lifetimebound]] { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); @@ -162,7 +163,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& operator()(Number i) + __host__ __device__ constexpr auto& operator()(Number i) [[clang::lifetimebound]] { return At(i); } diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 334d5851db..6d99f4e5e3 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -5,6 +5,9 @@ #include "ck/wrapper/utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -482,3 +485,4 @@ struct Layout } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 9f8278a357..ed7f2fa23d 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -7,6 +7,9 @@ #include "utils/tensor_partition.hpp" #include "utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -441,3 +444,4 @@ struct Tensor } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 438e44f5f1..91212292d2 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 732799cef8..30c93b8f00 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -11,6 +11,9 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { enum struct coord_transform_enum @@ -1776,3 +1779,4 @@ make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingA } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 8f9dd30bda..3a7231f71d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index bdc0daaed2..e26ac2e600 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1134,6 +1134,25 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32"); + +CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3(dwordx3_union vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset) +{ + int32x3_t v_reg; + v_reg[0] = vdata.as_i32[0]; + v_reg[1] = vdata.as_i32[1]; + v_reg[2] = vdata.as_i32[2]; + llvm_amdgcn_raw_buffer_store_i32x3_(v_reg, rsrc, voffset, soffset, 0); +}; + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, @@ -1290,7 +1309,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) { - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1330,6 +1349,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, return bit_cast(tmp); } + else if constexpr(N == 12) + { + auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + dwordx3_union ret; + ret.as_i32[0] = tmp[0]; + ret.as_i32[1] = tmp[1]; + ret.as_i32[2] = tmp[2]; + return bit_cast(ret); + } else if constexpr(N == 16) { int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, @@ -1411,15 +1442,19 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) || + (std::is_same::value && (N == 1)), "wrong! not implemented"); using rtn_type = thread_buffer; @@ -1750,7 +1785,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer(coherence)); } + else if constexpr(N == 12) + { + llvm_amdgcn_raw_buffer_store_i32x3(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); + } else if constexpr(N == 16) { llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), @@ -1859,10 +1901,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + std::is_same::value && (N == 1), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 4c9ef7d6ba..1eef5819bc 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -7,6 +7,9 @@ #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/ignore.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile::core::arch::mma { /** @@ -112,6 +115,7 @@ struct amdgcn_mma }; } // namespace ck_tile::core::arch::mma +#pragma clang diagnostic pop // Include the implementations #include "wmma/wmma.hpp" diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index d342235b38..8c861ceeb6 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // naive map @@ -157,3 +160,4 @@ CK_TILE_HOST_DEVICE static void print(const map& m) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 7f8176d5ec..11e7b1e52f 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef CK_TILE_TUPLE_IMPL #define CK_TILE_TUPLE_IMPL 1 #endif @@ -98,13 +101,14 @@ CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) } template -CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) +CK_TILE_HOST_DEVICE constexpr const T& +getv([[clang::lifetimebound]] const tuple_object& x) { return x.element; } template -CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) +CK_TILE_HOST_DEVICE constexpr T& getv([[clang::lifetimebound]] tuple_object& x) { return x.element; } @@ -292,7 +296,7 @@ struct tuple : impl::tuple_base, T...> //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast&>(*this).at(i); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } - + // template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(i) = x; } template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } @@ -864,3 +868,4 @@ struct tuple_element> } \ }() #endif +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp index 41aeb8ffab..ee12524283 100644 --- a/include/ck_tile/core/numeric/e8m0.hpp +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -6,6 +6,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /** @@ -100,3 +103,4 @@ CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index d74db6b336..5822e3b9bc 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -9,6 +9,9 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #if defined(__gfx950__) #define CK_TILE_FP4_CVT_DEVICE 1 #else @@ -517,3 +520,4 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const #endif } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/pk_fp6.hpp b/include/ck_tile/core/numeric/pk_fp6.hpp new file mode 100644 index 0000000000..0de61f6b1f --- /dev/null +++ b/include/ck_tile/core/numeric/pk_fp6.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +namespace ck_tile { +template +struct pk_fp6_t +{ + static constexpr index_t num_bits_elem = 6; + using element_type = int32_t; // element storage fundamental type + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_vec_elem = + sizeof(element_type) * 8; // 32-bit uint for storage + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + element_type data_[vector_size]; // packed data + using type = pk_fp6_t; + CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0) + { + for(size_t i = 0; i < vector_size; ++i) + { + data_[i] = value; + } + } + CK_TILE_HOST_DEVICE void pack(const int32_t x, const index_t i) + { + int32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + int32_t old_value = data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) + { + int32_t next_value = data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data_[arr_index + 1] = next_value; + } + } + + template + CK_TILE_HOST_DEVICE static int32_t unpack(const T& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + int32_t bits = pk.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); + } + + return bits & 0x3F; + } + + CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); } + + CK_TILE_HOST_DEVICE int32_t operator[](index_t i) const { return data_[i]; } + + CK_TILE_HOST_DEVICE static float fp6_e2m3_to_float(int32_t fp6_bits) + { + fp6_bits = fp6_bits & 0x3F; + + uint32_t sign = (fp6_bits >> 5) & 0x1; // bit 5 + uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3 + uint32_t mantissa = fp6_bits & 0x7; // bits 2-0 + + float result; + if(exponent == 0 && mantissa == 0) + { + result = 0.f; + } + else if(exponent != 0) + { + result = std::exp2f(static_cast(exponent) - 1); + float mantissa_value = 1.0f + mantissa / 8.0f; + result *= mantissa_value; + } + else + { + result = mantissa / 8.0f; + } + return sign == 1 ? -1 * result : result; + } +}; + +using pk_fp6x16_t = pk_fp6_t<16>; +using pk_fp6x32_t = pk_fp6_t<32>; +template <> +struct numeric_traits +{ + static constexpr int PackedSize = 16; +}; +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index deaa9e0bd9..634b845725 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/pk_fp6.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index def054f415..756bc7f6fc 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -160,6 +160,40 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x64_t = int32_t __attribute__((ext_vector_type(64))); +struct int32x3_tt +{ + int32_t data[3]; +}; + +struct int32x6_tt +{ + int32_t data[6]; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 12; + using value_type = int32x3_tt; + using type = int32x3_tt; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 1; + using value_type = int32x3_tt; + using type = int32x3_tt; +}; + +template <> +struct impl::ext_vector +{ + static constexpr index_t N = 2; + using value_type = int32x6_tt; + using type = int32x6_tt; +}; + // u32 // using uint32_t = ... using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index f3aeed6e61..59f82939b9 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -303,7 +303,6 @@ struct buffer_view>::scalar_type, - scalar_per_t_vector * scalar_per_x_vector>; - // using buf_t = ushort __attribute__((ext_vector_type(8))); - auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); - return bit_cast(rtn); + constexpr index_t load_elts = scalar_per_t_vector * scalar_per_x_vector; + if constexpr(load_elts == 12 && sizeof(typename X::value_type) == 1) + { + auto rtn = reinterpret_cast(p_data_) + (i + linear_offset) / 4; + struct + { + int32_t x, y, z; + } tmp = {rtn[0], rtn[1], rtn[2]}; + return bit_cast(tmp); + } + else + { + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); + return bit_cast(rtn); + } #endif } else @@ -968,6 +979,7 @@ struct buffer_view, int8x16_t> && std::is_same_v, int8x16_t>) || // int8 on thread buffer (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || + (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || (std::is_same_v, int8_t> && std::is_same_v, thread_buffer>) || @@ -1033,6 +1045,11 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else if constexpr(std::is_same_v, thread_buffer>) + { + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } else if constexpr((std::is_same_v, int8_t> && std::is_same_v, int8x16_t>) || (std::is_same_v, int8_t> && @@ -1075,6 +1092,12 @@ struct buffer_view(&p_data_[i]) = *c_style_pointer_cast(&x); } + else + { + static_assert(false, + "wrong! not implemented for this combination, please add " + "implementation"); + } } } else diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb..bdd81dae07 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/container/thread_buffer.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -266,3 +269,4 @@ inline constexpr bool is_similiar_distributed_tensor_v = } // namespace detail } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 78160b800d..e6cdb66ef9 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -12,6 +12,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // Transforms: Tuple @@ -950,3 +953,4 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. remove_cvref_t, \ remove_cvref_t>{trans}; \ }() +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp index 2ea76a3814..6d33bde83e 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -367,3 +370,4 @@ CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate& coord) detail::CK_PRINT_X_<>{}(coord); } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 837f2b87a6..833a7f4413 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* @@ -582,3 +585,4 @@ pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index f9c2aba502..aa5714e5c2 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -15,6 +15,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -731,3 +734,4 @@ CK_TILE_HOST_DEVICE void print(const tile_distribution #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -206,3 +209,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index aa4bfa3f15..ae79d575a8 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -10,6 +10,8 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck_tile { namespace detail { @@ -270,3 +272,4 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index 8c45d2b175..fee7f7779b 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* * a host side utility, arg parser for, either @@ -234,3 +237,4 @@ class ArgParser std::vector keys; }; } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 2ba3b1e7c3..a2f6728316 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -720,4 +720,57 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return err_count == 0; } +/** + * @brief Check errors between pk_fp6x16_t ranges + * + * Compares two ranges of pk_fp6x16_t without tolerance. + * This specialization handles ck_tile::pk_fp6x16_t type. + * + * @tparam Range Type of output range + * @tparam RefRange Type of reference range + * @param out Output range to check + * @param ref Reference range to check against + * @param msg Error message to display if check fails + * @return True if check passes, false otherwise + */ +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, pk_fp6x16_t>), + bool> + CK_TILE_HOST check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double = 0) +{ + if(check_size_mismatch(out, ref, msg)) + return false; + + int err_count = 0; + float max_err = 0.0f; + auto update_err = [&](float o, float r, std::size_t index) { + if(std::fabs(o - r) > 1e-8) + { + std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o + << " != " << r << std::endl; + ++err_count; + max_err = max_err < std::fabs(o - r) ? o : max_err; + } + }; + for(std::size_t i = 0; i < ref.size(); ++i) + { + const pk_fp6x16_t o = *std::next(std::begin(out), i); + const pk_fp6x16_t r = *std::next(std::begin(ref), i); + for(std::size_t j = 0; j < numeric_traits::PackedSize; j++) + { + update_err(o.unpack(j), r.unpack(j), i * numeric_traits::PackedSize + j); + } + } + if(err_count > 0) + { + report_error_stats(err_count, max_err, ref.size()); + } + return err_count == 0; +} + } // namespace ck_tile diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index d26686ec37..ddeb3ad781 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -17,6 +17,9 @@ #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/ranges.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -859,3 +862,4 @@ auto get_default_stride(std::size_t row, return stride; } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 7830150b63..da6b074b98 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -625,6 +625,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, a_m_k_scaled(m, k) = a_f4_lo * a_scale; a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + a_m_k_scaled(m, k + k_) = + pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale; + } + } else { a_m_k_scaled(m, k) = @@ -653,6 +664,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, b_k_n_scaled(k, n) = b_f4_lo * b_scale; b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; } + else if constexpr(std::is_same_v) + { + if(k % pk_fp6x16_t::packed_size != 0) + continue; + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++) + { + b_k_n_scaled(k + k_, n) = + pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale; + } + } else { b_k_n_scaled(k, n) = diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de..4a30e3af16 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -22,6 +22,7 @@ template <> struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp6x16"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; template struct memOpToStr; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index bc7d2323d0..23d7a9fca9 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 + ? 16 + : 16 /*dwordx4*/ * APackedSize / sizeof(ADataType); + static constexpr index_t BK1 = std::is_same_v + ? 16 + : 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -537,24 +542,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); auto a_store_lds_window_pong = make_tile_window( // a_lds_block_pong, - make_tuple(number{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); - auto a_warp_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_ping = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); + auto a_warp_window_pong = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution()); // B flat DRAM window for load @@ -621,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { @@ -663,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> @@ -683,7 +690,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); __builtin_amdgcn_sched_barrier(0); @@ -750,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -760,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); @@ -839,7 +848,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_pong, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished @@ -849,7 +858,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); HotLoopScheduler(); }; @@ -874,7 +884,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 0); } - // TAIL if constexpr(TailNum == TailNumber::Even) { @@ -933,7 +942,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = load_tile_with_offset( // a_warp_window_ping, tuple, - number>{}); + number>{}); } }); // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished @@ -947,7 +956,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, number>{}); + tuple, + number>{}); }); Last2ndHotLoopScheduler(); @@ -977,12 +987,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_pong, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, + number>{}); } }); LastHotLoopScheduler(); @@ -1014,12 +1024,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}) = - load_tile_with_offset(a_warp_window_ping, - tuple, - number>{}); + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, + number>{}); } }); LastHotLoopScheduler(); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 34d18cb8e1..7cf6326dfd 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy static constexpr index_t kDramLoadPackBytes = 128; static constexpr index_t DWORDx4 = 16; + static constexpr index_t DWORDx3 = 12; static constexpr int MXdlPack = 2; static constexpr int NXdlPack = 2; @@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution() { - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / K2; // 8 - constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + constexpr index_t K0 = + KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize constexpr index_t M2 = WaveSize / K1; // 8 constexpr index_t M1 = BlockSize / WaveSize; // 4 constexpr index_t M0 = MPerBlock / (M2 * M1); static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType), "K0, K1, K2 must cover whole KPerBlock!"); return make_static_tile_distribution( @@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); - constexpr index_t K2 = DWORDx4; // 16 bytes - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - const index_t K0 = cols / (K1 * K2 * APackedSize); + constexpr index_t K2 = std::is_same_v ? DWORDx3 : DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8 + const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize); const auto col_lens = make_tuple(K0, number{}, number{}); constexpr index_t M1 = 4; // so that we can use imm offset to load lds @@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); - auto&& origin_tmp = window_tmp.get_window_origin(); + auto&& origin_tmp = window_tmp.get_window_origin(); + constexpr index_t test1 = APackedSize / sizeof(ADataType); return make_tile_window(byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / APackedSize}, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / test1}, MakeMX_ABytesDramTileDistribution()); } CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor() { - constexpr index_t K2 = AK1 / APackedSize; // 16 - constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 - constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256 - static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + constexpr index_t K2 = std::is_same_v ? DWORDx3 : AK1 / APackedSize; + constexpr index_t K2_Pad = 16; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8 + constexpr index_t K0 = std::is_same_v + ? KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize) + : KPerBlock / (K1 * AK1); // KPerBlock/256 + static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); constexpr index_t M3 = 4; // so that we can use imm offset to load lds @@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - number{}, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, number<1>{}), number{}, number<1>{}); @@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - if constexpr(K_Thread == AK1) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<0, 2>>, sequence<2>, sequence<1>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< sequence, @@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<1, 2>>, sequence<2, 2>, sequence<0, 2>>{}); + else if constexpr(std::is_same_v) + // K_Lane=4, K_Thread=32 + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<2, 2>, + sequence<1, 2>>{}); + else + static_assert(false, "unsupported datatype"); } CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() @@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - if constexpr(BK1 == K_Thread) + if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 16 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>{}); - else + else if constexpr(std::is_same_v) return make_static_tile_distribution( tile_distribution_encoding< // sequence, @@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>{}); + else if constexpr(std::is_same_v) + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 64 1 2 12 + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<2, 2>, + sequence<2, 3>>{}); + else + static_assert(false, "unsupported datatype"); } template @@ -280,21 +314,27 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile; auto&& byte_tensor_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple( - flat_n, flat_k / flat_k_per_block, number{})), + make_naive_tensor_descriptor_packed( + make_tuple(flat_n, + flat_k / flat_k_per_block, + number{})), make_tuple(make_pass_through_transform(flat_n), make_merge_transform_v3_division_mod(make_tuple( - flat_k / flat_k_per_block, number{}))), + flat_k / flat_k_per_block, + number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); auto&& byte_tensor_view = make_tensor_view(byte_ptr, byte_tensor_desc); auto&& origin_tmp = window_tmp.get_window_origin(); + auto origin_n = origin_tmp[0]; + auto origin_k = static_cast(origin_tmp[1] * sizeof(BDataType) / BPackedSize); return make_tile_window( byte_tensor_view, - make_tuple(number{}, number{}), - {origin_tmp[0], origin_tmp[1] / BPackedSize}, + make_tuple(number{}, + number{}), + {origin_n, origin_k}, MakeMX_BFlatBytesDramTileDistribution()); } @@ -372,7 +412,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + if constexpr(!std::is_same_v) + { + return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } + else + { + return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size(); + } } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 957cf7ab8f..987704e433 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -41,7 +41,8 @@ enum struct TailNumber } // namespace ck_tile -inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::GemmPipelineScheduler& s) { switch(s) { @@ -53,7 +54,8 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch return os; } -inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::TailNumber& s) { switch(s) { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 9e23a06b23..24076ca494 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1614,7 +1614,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4 return make_tuple(number<0>{}, int32x8_t{}); else if constexpr(std::is_same_v) return make_tuple(number<1>{}, int32x8_t{}); - // else if e2m3 => make_tuple(number<2>{}, int32x6_t{}) + else if constexpr(std::is_same_v) + return make_tuple(number<2>{}, pk_fp6x32_t{}); // else if e3m2 => make_tuple(number<3>{}, int32x6_t{}) else if constexpr(std::is_same_v) return make_tuple(number<4>{}, int32x4_t{}); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b49..db86fdbeac 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -380,9 +380,18 @@ struct QuantGemmKernel __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); - const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); + constexpr auto K1 = + GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block + const index_t K_t = amd_wave_read_first_lane( + kargs.k_batch * K1); // amount of K elements consumed if every split-K batch + // performs exactly one "unit" (K1) + const index_t KRead = amd_wave_read_first_lane( + (kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch + // offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4) + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + const index_t b_k_offset_elements = + amd_wave_read_first_lane(k_id * KRead / BPackedSize); if constexpr(std::is_same_v) { @@ -395,11 +404,11 @@ struct QuantGemmKernel if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements); } if(k_id < static_cast(kargs.k_batch - 1)) @@ -410,10 +419,47 @@ struct QuantGemmKernel { splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } + + // Compute BQ offset for BQuantGrouped mode (non-preshuffle only) + // Note: With the alignment validation in IsSupportedArgument, KRead is always + // a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. + if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant) + { + using BQuantGroupSize = remove_cvref_t; + // Compute the K offset for this batch (in terms of K elements) + const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); + // Convert K offset to BQ group offset (logical offset in K/kK dimension) + bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK); + + // BQ tensor layout: + // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] + // ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1] + if constexpr(std::is_same_v) + { + // For RowMajor BQ, K is the row dimension + // offset = bq_group_offset * stride_BQ + const index_t stride_bq = + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN)); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); + } + else if constexpr(std::is_same_v) + { + // For ColumnMajor BQ, K is the column dimension + // offset = bq_group_offset + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); + } + } + else + { + bq_group_offset = 0; + bq_k_split_offset = 0; + } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) index_t splitted_k; }; @@ -805,10 +851,13 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, const QuantGemmKernelArgs& kargs, + const index_t bq_group_offset, const index_t i_m, const index_t i_n) { // Step 1: Create tensor view for BQ + // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset). + // The dimension should use the remaining K-groups from this offset position. const auto& bq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::RowColQuant) { @@ -850,11 +899,12 @@ struct QuantGemmKernel "ABQuantGrouped requires ColumnMajor BQ layout"); } + using BQuantGroupSize = remove_cvref_t; if constexpr(std::is_same_v) { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), + make_tuple(kargs.QK_B - bq_group_offset, integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, @@ -865,8 +915,8 @@ struct QuantGemmKernel return make_naive_tensor_view( bq_ptr, make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), - integer_divide_ceil(kargs.K, BQuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1), + kargs.QK_B - bq_group_offset), + make_tuple(kargs.QK_B, 1), number{}, number<1>{}); } @@ -1047,13 +1097,61 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { + // Split-K is supported for BQuantGrouped mode without preshuffle if(kargs.k_batch != 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + constexpr bool is_bquant_non_preshuffle = + (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; + if constexpr(!is_bquant_non_preshuffle) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 ! " + "Split-K only supported for BQuantGrouped without preshuffle."); + } + return false; + } + else + { + using BQuantGroupSize = remove_cvref_t; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Constraint 1: KRead must align with B packing requirements. + // For packed data types, multiple K elements are stored in each storage unit. + // Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch. + // If KRead is not divisible by BPackedSize, this division produces a fractional + // offset, making it impossible to start reading from a valid storage unit boundary. + if(KRead % BPackedSize != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!"); + } + return false; + } + + // Constraint 2: KRead must align with quantization group boundaries. + // Each split-K batch reads KRead consecutive K elements. If KRead is not + // a multiple of BQuantGroupSize::kK, the batch will span partial quantization + // groups, requiring split access to a quantization scale. This violates the + // atomic processing requirement where each batch must work with complete groups. + if(KRead % BQuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by BQuantGroupSize::kK=" + + std::to_string(BQuantGroupSize::kK)); + } + return false; + } } - return false; } if constexpr(std::is_same_v) @@ -1215,7 +1313,10 @@ struct QuantGemmKernel const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass bq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position. + const auto& bq_block_window = MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -1343,8 +1444,9 @@ struct QuantGemmKernel const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); - const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const BQDataType* bq_ptr = + static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index c9e725f5fd..8b77b01e2f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 28674554a1..fd698ee340 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + class ProfilerOperationRegistry final { ProfilerOperationRegistry() = default; @@ -83,3 +86,4 @@ class ProfilerOperationRegistry final ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) \ _Pragma("clang diagnostic pop") // clang-format on +#pragma clang diagnostic pop diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8e005d588e..2b19053f41 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant split-K tests (no preshuffle) + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode + test_gemm_quant_bquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill + test_gemm_quant_bquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleB) - split into 5 files add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d test_gemm_quant_bquant_preshuffle_decode_1d.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp new file mode 100644 index 0000000000..ea1a8a1fbb --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKDecodeTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + this->run_test_with_validation(32, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + this->run_test_with_validation(32, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + this->run_test_with_validation(32, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4×128 ✓ + // Also K must be divisible by K_Tile(256)*split_k(5)=1280 + this->run_test_with_validation(32, 128, 2560, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp new file mode 100644 index 0000000000..f4f93dbbb6 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKPrefillTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(2)=256 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + // K must be divisible by K_Tile(128)*split_k(3)=384 + this->run_test_with_validation(128, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(4)=512 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=1920 for split_k=5: 1920/5=384=3×128 ✓ + // K must be divisible by K_Tile(128)*split_k(5)=640 + this->run_test_with_validation(128, 128, 1920, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a8..ca21bc69b7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; @@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBasetemplate calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) { diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 1390e5ee07..2fac12ebe4 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,19 +23,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - add_gtest_executable(test_ck_tile_streamk_reduction - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - test_gemm_streamk_util.cpp) - add_gtest_executable(test_ck_tile_streamk_smoke - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_persistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp - test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_extended ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp @@ -46,7 +33,6 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp test_gemm_streamk_util.cpp) - target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target") diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp deleted file mode 100644 index 95117b6f0d..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp deleted file mode 100644 index 5e0705ab29..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp deleted file mode 100644 index 21e447af29..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp deleted file mode 100644 index 62b7767a69..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_bf8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp deleted file mode 100644 index fc18b9ebf7..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp deleted file mode 100644 index 8756da4ad8..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp deleted file mode 100644 index bcd4583da2..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction - -TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction); - -#include "test_gemm_streamk_reduction_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp deleted file mode 100644 index 58dca5ca1d..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp deleted file mode 100644 index 1d1e1e31ec..0000000000 --- a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp8_persistent.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_gemm_streamk_common_includes.hpp" - -template -class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK -{ -}; - -#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent - -TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent); - -#include "test_gemm_streamk_smoke_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc deleted file mode 100644 index 66c3e3b5e9..0000000000 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 4; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 4; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 3; - ck_tile::index_t N = N_Tile * 7; - ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - ck_tile::index_t M = M_Tile * 3; - ck_tile::index_t N = N_Tile * 7; - ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile); - - this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); -} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc deleted file mode 100644 index 4bd6e9d973..0000000000 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_smoke_cases.inc +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase) -{ - ck_tile::index_t M = 256; - ck_tile::index_t N = 256; - ck_tile::index_t K = 256; - - this->Run(M, N, K); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - // For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This - // assumes tile sizes are large enough such that occupancy is 1. - ck_tile::index_t M = M_Tile * num_cu; - ck_tile::index_t N = N_Tile; - ck_tile::index_t K = K_Tile; - - this->Run(M, N, K); -} - -TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly) -{ - const ck_tile::index_t num_cu = get_cu_count(); - constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; - constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; - constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; - - // For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along - // the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu - // macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy - // is 1. - ck_tile::index_t M = M_Tile * 2; - ck_tile::index_t N = N_Tile * 2; - ck_tile::index_t K = K_Tile * num_cu; - - this->Run(M, N, K); -} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index ece313b8aa..efb7416580 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -33,14 +33,6 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> >; -using KernelTypesStreamKFp16Reduction = ::testing::Types< -// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>; - using KernelTypesStreamKBf16Persistent = ::testing::Types< std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt index 664866d458..8f9bd39886 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -1,6 +1,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +include(generate_configs.cmake) + # ============================================================================ # GEMM Tile Engine Unit Tests # @@ -87,7 +89,7 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) endif() - message(STATUS " Created test target: ${target_name}") + message(DEBUG " Created test target: ${target_name}") endfunction() # ============================================================================ @@ -101,12 +103,12 @@ endfunction() # layout - Matrix layout (rcr, rrr, ccr, crr) # config_name - Configuration file name without .json extension # ============================================================================ -function(build_gemm_test_targets datatype layout config_name) +function(build_gemm_test_targets datatype layout config_name configs_dir_path) set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") # Locate and validate configuration file set(config_filename "${config_name}.json") - set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") + set(json_blob "${configs_dir_path}/${config_filename}") if(NOT EXISTS ${json_blob}) message(WARNING "Test config file not found: ${json_blob}") @@ -137,11 +139,11 @@ function(build_gemm_test_targets datatype layout config_name) # Verify kernel list file was generated if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) - message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") + message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") return() endif() - message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}") # STEP 2a: Extract test parameters from config set(test_params_file "${working_path}/test_params.hpp") @@ -230,7 +232,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") # GPU architecture filtering - only build tests for supported architectures set(GEMM_TEST_GPU_TARGETS "") -set(DESIRED_TARGETS "gfx90a;gfx942") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) @@ -241,7 +243,7 @@ endforeach() # Early exit if no compatible GPU architectures are available if(NOT GEMM_TEST_GPU_TARGETS) - message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") return() endif() @@ -282,25 +284,35 @@ set(TEST_LAYOUTS "rcr;rrr;ccr;crr") # Test Target Generation - Datatype-Specific Categories # ============================================================================ -# 1. SIMPLE TEST: Test for basic functionality with data types (fp16, bf16) -# These data types can use larger warp tiles due to smaller memory footprint -set(SIMPLE_TEST_CONFIG "simple_test_config") -set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json") -set(SIMPLE_DATATYPES "fp16;bf16") +# 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16) +set(SMALL_DATATYPES "fp16;bf16;fp8;bf8") +set(SIXTEEN_BIT_DATATYPES "fp16;bf16") +set(EIGHT_BIT_DATATYPES "fp8;bf8") +set(LARGE_TILES "256,256,32") +set(SMALL_TILES "128,128,32") +set(CONFIG_LIST "") +set(GENERATED_CONFIG_PATH ${CMAKE_CURRENT_BINARY_DIR}/configs) +get_cu_count(CU_COUNT) -if(EXISTS ${SIMPLE_TEST_CONFIG_FILE}) - message(STATUS "Processing simple test config: ${SIMPLE_TEST_CONFIG} (fp16, bf16)") - foreach(datatype IN LISTS SIMPLE_DATATYPES) - # fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) +message(STATUS "Generating and processing configs for Stream-K tests") +foreach(datatype IN LISTS SMALL_DATATYPES) + + if(datatype IN_LIST SIXTEEN_BIT_DATATYPES) + generate_test_configs(${CU_COUNT} ${LARGE_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH}) + else() + generate_test_configs(${CU_COUNT} ${SMALL_TILES} ${datatype} CONFIG_LIST ${GENERATED_CONFIG_PATH}) + endif() + + foreach(config IN LISTS CONFIG_LIST) + # testing all layouts (rcr, rrr, ccr, crr) foreach(layout IN LISTS TEST_LAYOUTS) - build_gemm_test_targets("${datatype}" "${layout}" "${SIMPLE_TEST_CONFIG}") + build_gemm_test_targets("${datatype}" "${layout}" "${config}" "${GENERATED_CONFIG_PATH}") endforeach() endforeach() -else() - message(WARNING "Simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}") -endif() +endforeach() + # ============================================================================ message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:") -message(STATUS " - Simple test: fp16/bf16 (all layouts)") +message(STATUS " - Smoke tests: fp16/bf16/fp8/bf8 (all layouts)") diff --git a/test/ck_tile/gemm_streamk_tile_engine/README.md b/test/ck_tile/gemm_streamk_tile_engine/README.md index 4655673852..965342536b 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/README.md +++ b/test/ck_tile/gemm_streamk_tile_engine/README.md @@ -34,17 +34,25 @@ Each test configuration can specify optimized problem sizes in its JSON file: The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. ## Test Configurations +Test configs are generated during the Generation Phase. They are stored under the build directory at test/ck_tile/gemm_streamk_tile_engine/configs. The Compute Unit (CU) count of the device is required to generate the configs. If the Generation Phase occurs on a machine without a GPU or does not contain same GPU architecture on which you will run the tests, you can manually set the CU count using the `CU_COUNT` option: +```bash +# Assuming you are at the root of the repo +cd build +../script/cmake-ck-dev.sh .. gfx90a -G Ninja -DCU_COUNT=100 +``` +You can reference the public whitepaper for your specific GPU to get the appropriate CU count. +If no `CU_COUNT` option is given and no HIP device is found, then the default value of 100 CUs will be used to determine the problem sizes tested. -### 1. **Simple Test** (`simple_test_config.json`) -- **Purpose**: Basic functionality validation for fp16/bf16 data types -- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +### 1. **Smoke Tests** +- **Purpose**: Basic functionality validation for fp16/bf16/fp8/bf8 data types +- **Config**: 256x256x32 (for bf16/fp16) or 128x128x32 (for bf8/fp8), warp 2x2x1, warp_tile 32x32x16 - **Traits**: compv3 pipeline only -- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp16, bf16 +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) ## Data Type Support -- ✅ **fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ✅ **fp16, bf16, fp8, bf8**: Fully supported - all layouts (rcr, rrr, ccr, crr) - ❌ **fp64**: Not supported (hardware MFMA limitation) -- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) +- ⏳ **fp32, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) ## Test Result Behavior diff --git a/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json deleted file mode 100644 index 1cfeef7570..0000000000 --- a/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "problem": { - "description": "Basic functionality validation with moderate problem sizes" - }, - "test_params": { - "problem_sizes": [ - {"m": 256, "n": 256, "k": 128, "split_k": 1}, - {"m": 512, "n": 256, "k": 256, "split_k": 1}, - {"m": 256, "n": 512, "k": 256, "split_k": 1} - ] - }, - "tile_config": { - "tile_m": {"values": [128]}, - "tile_n": {"values": [128]}, - "tile_k": {"values": [64]}, - "warp_m": {"values": [2]}, - "warp_n": {"values": [2]}, - "warp_k": {"values": [1]}, - "warp_tile_m": {"values": [16]}, - "warp_tile_n": {"values": [16]}, - "warp_tile_k": {"values": [16]} - }, - "trait_config": { - "pipeline": {"values": ["compv3"]}, - "epilogue": {"values": ["default"]}, - "scheduler": {"values": ["intrawave"]}, - "pad_m": {"values": [false]}, - "pad_n": {"values": [false]}, - "pad_k": {"values": [false]}, - "persistent": {"values": [false, true]}, - "reduction_strategy": {"values": ["atomic"]} - }, - "k_block_per_cu": 1, - "permute_n": false -} diff --git a/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp new file mode 100644 index 0000000000..88a2d06901 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/cu_count.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +/** + * @brief Determines whether a `hipError` is present in the given `error_status` + * @return true if the `error_status` has an error, otherwise false. + */ +bool has_error(const hipError_t& error_status) +{ + if(error_status != hipSuccess) + { + std::cerr << hipGetErrorString(error_status); + return true; + } + + return false; +} + +/** + * @brief Returns the number of Compute Units (CUs) on the given device. + * @return The number of CUs on the device. If an error occurs while querying the device, zero is + * returned. + */ +int get_cu_count() +{ + hipDevice_t dev; + hipDeviceProp_t dev_prop; + + const hipError_t device_status = hipGetDevice(&dev); + + if(has_error(device_status)) + return 0; + + const hipError_t prop_status = hipGetDeviceProperties(&dev_prop, dev); + if(has_error(prop_status)) + return 0; + + return dev_prop.multiProcessorCount; +} + +int main() { return get_cu_count(); } diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake new file mode 100644 index 0000000000..4f18b5dcbe --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.cmake @@ -0,0 +1,103 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(CU_COUNT 0 CACHE STRING "Number of Compute Units on the device") + +# ============================================================================ +# get_cu_count +# +# Returns the CU count for the device. If the given cu_count_arg is a positive +# integer, then the nothing happens. Otherwise, we attempt to query the CU +# count from the device. If the query is unsucessful, the default value of 100 +# is returned. +# +# Parameters: +# cu_count_arg - The starting CU count +# ============================================================================ +function(get_cu_count cu_count_arg) + message(STATUS "Starting query for CU count needed for Stream-K test config generation") + + if(NOT "${${cu_count_arg}}" MATCHES "^[0-9]+$") + message(FATAL_ERROR "The CU count must be a non-negative integer. \ + The given value of ${${cu_count_arg}} is invalid.") + endif() + + if("${${cu_count_arg}}" STREQUAL "0") + + set(CPP_FILE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cu_count.cpp) + set(CPP_EXE_PATH ${CMAKE_CURRENT_BINARY_DIR}/cu_count) + + execute_process( + COMMAND ${CMAKE_HIP_COMPILER} -x hip ${CPP_FILE_PATH} -o ${CPP_EXE_PATH} + RESULT_VARIABLE compile_result + ) + + if (NOT compile_result EQUAL 0) + message(FATAL_ERROR "Compilation of ${CPP_FILE_PATH} failed.\n") + endif() + + execute_process( + COMMAND ${CPP_EXE_PATH} + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_VARIABLE standard_error + RESULT_VARIABLE queried_cu_count + ) + + if (standard_error) + message(STATUS "Error information from attempting to query HIP device and properties:\n" + "${standard_error}") + endif() + + + # Delete the generated cu_count executable + file(REMOVE "${CPP_EXE_PATH}") + + if(queried_cu_count EQUAL 0) + message(WARNING "Unable to query the number of Compute Units. \ + Please use the CU_COUNT CLI option to pass in the \ + number of Compute Units for your target device; otherwise, \ + the default value of 100 will be used.") + set(${cu_count_arg} 100 PARENT_SCOPE) + else() + set(${cu_count_arg} ${queried_cu_count} PARENT_SCOPE) + endif() + + endif() + +endfunction() + +# ============================================================================ +# generate_test_configs +# +# Generate config json files for Stream-K tests +# +# Parameters: +# cu_count_arg - The number of CUs on the device +# tile_sizes - A list of block tile sizes: tile_m,tile_n,tile_k +# datatype - The datatype for which the config is being generated +# config_list - The variable to which the list of config file names are written +# configs_path - Path to the configs directory to which config files are written +# ============================================================================ +function(generate_test_configs cu_count_arg tile_sizes datatype config_list configs_path) + message(STATUS "Generating Stream-K test config files for ${datatype}") + + file(MAKE_DIRECTORY ${configs_path}) + + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py + --cu_count ${cu_count_arg} + --configs_dir_path ${configs_path} + --tiles ${tile_sizes} + --datatype ${datatype} + OUTPUT_VARIABLE CONFIG_LIST + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE script_ret_val + ) + + if (NOT script_ret_val EQUAL 0) + message(FATAL_ERROR "Eror occured during execution of ${CMAKE_CURRENT_SOURCE_DIR}/generate_configs.py") + endif() + + set(${config_list} ${CONFIG_LIST} PARENT_SCOPE) + +endfunction() diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py new file mode 100644 index 0000000000..ba075a2729 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +from enum import Enum +from typing import Dict, Tuple, List +import argparse +import json +import os +import sys +from dataclasses import dataclass, field, asdict + + +@dataclass +class TileConfig: + """Represents the Tile Config section of a Tile Engine config""" + + tile_m: List[int] = field(default_factory=list) + tile_n: List[int] = field(default_factory=list) + tile_k: List[int] = field(default_factory=list) + warp_m: List[int] = field(default_factory=lambda: [2]) + warp_n: List[int] = field(default_factory=lambda: [2]) + warp_k: List[int] = field(default_factory=lambda: [1]) + warp_tile_m: List[int] = field(default_factory=lambda: [32]) + warp_tile_n: List[int] = field(default_factory=lambda: [32]) + warp_tile_k: List[int] = field(default_factory=lambda: [16]) + + def to_dict(self) -> Dict: + return {k: {"values": v} for k, v in asdict(self).items()} + + +@dataclass +class TraitConfig: + """Represents the Trait Config section of a Tile Engine config""" + + pipeline: List[str] = field(default_factory=lambda: ["compv3"]) + epilogue: List[str] = field(default_factory=lambda: ["cshuffle"]) + scheduler: List[str] = field(default_factory=lambda: ["intrawave"]) + pad_m: List[bool] = field(default_factory=lambda: [False]) + pad_n: List[bool] = field(default_factory=lambda: [False]) + pad_k: List[bool] = field(default_factory=lambda: [False]) + persistent: List[bool] = field(default_factory=lambda: [True, False]) + reduction_strategy: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict: + return {k: {"values": v} for k, v in asdict(self).items()} + + +class TestVariant(Enum): + """Represents a Stream-K test variant""" + + def __init__( + self, + val: int, + reduction_strategy: List[str], + persistent: List[bool], + datatypes: List[str], + description: str, + ): + self._value_ = val + self.reduction_strategy = reduction_strategy + self.persistent = persistent + self.datatypes = datatypes + self.description = description + + ATOMIC_SMOKE = ( + 0, + ["atomic"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K atomic smoke tests", + ) + REDUCTION_SMOKE = ( + 2, + ["reduction", "tree"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K reduction smoke tests", + ) + EXTENDED = ( + 3, + ["atomic"], + [True, False], + ["fp16", "bf16", "fp8", "bf8"], + "Stream-K extended smoke tests", + ) + + def apply(self, trait_config: TraitConfig) -> None: + """Applies the current test variant's persistent and reduction strategy setting to the given trait_config""" + trait_config.persistent = self.persistent + trait_config.reduction_strategy = self.reduction_strategy + + +@dataclass +class ProblemSize: + """Represents a problem size in a Tile Engine config""" + + m: int + n: int + k: int + variant: TestVariant + split_k: int = 1 + + def to_dict(self) -> Dict: + return {"m": self.m, "n": self.n, "k": self.k, "split_k": self.split_k} + + +@dataclass +class Config: + """Represents a Tile Engine config""" + + description: str + problem_sizes: list[ProblemSize] = field(default_factory=list) + tile_config: TileConfig = field(default_factory=TileConfig) + trait_config: TraitConfig = field(default_factory=TraitConfig) + k_block_per_cu: int = 1 + permute_n: bool = False + + def add_problem_size(self, problem: ProblemSize) -> None: + """Adds the given problem to this config's problem_sizes""" + self.problem_sizes.append(problem) + + def to_dict(self) -> Dict: + config_dict = { + "problem": {"description": f"{self.description}"}, + "test_params": { + "problem_sizes": [ps.to_dict() for ps in self.problem_sizes] + }, + "tile_config": self.tile_config.to_dict(), + "trait_config": self.trait_config.to_dict(), + "k_block_per_cu": self.k_block_per_cu, + "permute_n": self.permute_n, + } + return config_dict + + def write_to_file(self, output_file: str) -> None: + """Writes this configs to the given output_file in a json format""" + with open(output_file, "w") as config_file: + json.dump(self.to_dict(), config_file, indent=4) + config_file.write("\n") + + +def create_problem_sizes( + tile_m: int, tile_n: int, tile_k: int, cu_count: int +) -> List[ProblemSize]: + """Creates and returns a list of problem sizes using the given arguments""" + problem_sizes = [ + ProblemSize(256, 256, 256, TestVariant.ATOMIC_SMOKE), + ProblemSize(tile_m * cu_count, tile_n, tile_k, TestVariant.ATOMIC_SMOKE), + ProblemSize( + tile_m * 2, tile_n * 2, cu_count * tile_k, TestVariant.ATOMIC_SMOKE + ), + ProblemSize(tile_m, tile_n, cu_count * tile_k, TestVariant.REDUCTION_SMOKE), + ProblemSize( + tile_m * 4, + tile_n, + tile_k * cu_count + (25 * tile_k), + TestVariant.REDUCTION_SMOKE, + ), + ProblemSize( + tile_m * 3, + tile_n * 7, + tile_k * cu_count + (30 * tile_k), + TestVariant.REDUCTION_SMOKE, + ), + # TODO: Add this test once we determine how to label tests as regresion with tile engine + # ProblemSize((tile_m * cu_count * 2) + (tile_m * 2), tile_n, 2048, TestVariant.EXTENDED) + ] + + return problem_sizes + + +def write_config_files( + problem_sizes: List[ProblemSize], + configs_dir_path: str, + datatype: str, + tile_sizes: Tuple[int, int, int], +) -> str: + """Writes the given problem_sizes to a config file and returns the names of the config files written to""" + config_names = [] + tile_m, tile_n, tile_k = tile_sizes + tile_config = TileConfig([tile_m], [tile_n], [tile_k]) + + # Create a config for each test variant + for variant in TestVariant: + problem_sizes_filtered = [ps for ps in problem_sizes if ps.variant == variant] + + if (datatype not in variant.datatypes) or len(problem_sizes_filtered) == 0: + continue + + trait_config = TraitConfig() + variant.apply(trait_config) + config_name = f"streamk_{variant.name.lower()}_tests_config_{datatype}" + config_names.append(config_name) + file_path = os.path.join(configs_dir_path, config_name + ".json") + config = Config( + variant.description, problem_sizes_filtered, tile_config, trait_config + ) + config.write_to_file(file_path) + + return config_names + + +def print_config_names(config_file_names: List[str]) -> None: + """Prints given config file names as a single semi-colon separated string""" + print(";".join(config_file_names)) + + +def create_config_files( + cu_count: int, configs_dir_path: str, tile_sizes: int, datatype: str +) -> None: + """Creates Stream-K test config files and prints the file names in a semi-colon-separated list""" + tile_m, tile_n, tile_k = tile_sizes + + problem_sizes = create_problem_sizes(tile_m, tile_n, tile_k, cu_count) + config_names = write_config_files( + problem_sizes, configs_dir_path, datatype, tile_sizes + ) + print_config_names(config_names) + + +def get_args() -> Tuple[int, str, Tuple[int, int, int], str]: + """Returns user provided arguments""" + + def tile_sizes_type(val: str): + sizes = None + parts = val.split(",") + if len(parts) != 3: + raise argparse.ArgumentTypeError( + "--tiles must contain exactly three comma-separated values (m,n,k), e.g. --tiles 256,256,32" + ) + try: + sizes = tuple(int(size) for size in parts) + except ValueError: + raise argparse.ArgumentTypeError( + "--tiles must contain exactly three comma-separated integers (m,n,k), e.g. --tiles 256,256,32" + ) + + return sizes + + parser = argparse.ArgumentParser(description="Create Stream-K test configs") + parser.add_argument( + "--cu_count", required=True, help="Number of Compute Units on the device" + ) + parser.add_argument( + "--configs_dir_path", + required=True, + help="Full path configs directory where config files will be written to", + ) + + parser.add_argument( + "--tiles", + required=True, + type=tile_sizes_type, + help="Block tile sizes for m, n, and k, respectively. Ex: --tiles 256,256,32", + ) + + parser.add_argument( + "--datatype", + choices=["fp16", "bf16", "fp8", "bf8"], + required=True, + help="The datatype for which the config is generated.", + ) + + args = parser.parse_args() + + return (int(args.cu_count), args.configs_dir_path, args.tiles, args.datatype) + + +def main(): + cu_count, configs_dir_path, tile_sizes, datatype = get_args() + create_config_files(cu_count, configs_dir_path, tile_sizes, datatype) + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp index 913e7d8531..1c06d33e77 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp +++ b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp @@ -12,6 +12,7 @@ #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -126,13 +127,18 @@ class StreamKGemmTileEngineTest : public ::testing::TestWithParam 0) << "Kernel name should not be empty"; + + std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; + std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << std::endl; + // Get tensor layouts from generated kernel const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; const CLayout layout_c = CLayout{}; - // Use split_k from test parameters - int split_k = split_k_; + // Calculate tensor strides int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); @@ -144,27 +150,42 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); - ck_tile::HostTensor c_m_n_host_result( + ck_tile::HostTensor c_m_n_dev_ref( ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + c_m_n_dev_ref.SetZero(); // Allocate GPU device memory ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes()); // Copy data to device and zero output buffer a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + ref_c_m_n_dev_buf.SetZero(); - // Calculate reference result on host for verification - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_result); + // Calculate reference result on device for verification + ADataType* a_m_k_dev_ref_ptr = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* b_k_n_dev_ref_ptr = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* c_m_n_dev_ref_ptr = static_cast(ref_c_m_n_dev_buf.GetDeviceBuffer()); + ck_tile:: + reference_gemm_gpu( + a_m_k_dev_ref_ptr, + b_k_n_dev_ref_ptr, + c_m_n_dev_ref_ptr, + m_, + n_, + k_, + stride_a_calc, + stride_b_calc, + stride_c_calc); + ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data()); // Create GEMM kernel arguments ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), @@ -188,9 +209,10 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) 1}; // rotating_count // Launch the generated kernel (no timing overhead for fastest execution) + std::tuple launch_result; try { - SelectedKernel::launch(args, stream_config); + launch_result = SelectedKernel::launch(args, stream_config); // Kernel launched successfully if no exception thrown } catch(const std::exception& e) @@ -211,22 +233,13 @@ TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); // Verify results using tile_engine's adaptive error thresholds + const ck_tile::index_t num_wgs_per_tile = get<1>(launch_result); bool verification_passed = compare_results( - KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); + KERNEL_NAME, k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_dev_ref); EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; } -TEST_P(StreamKGemmTileEngineTest, KernelInfo) -{ - // Simple test to verify kernel information is available - EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; - - std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; - std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ - << std::endl; -} - // Use config-specific test parameters (included via compile flags) // CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file INSTANTIATE_TEST_SUITE_P(GemmVerification, diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp index 2a43b596e4..208b92e702 100644 --- a/test/ck_tile/memory_copy/test_copy.cpp +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -20,6 +20,25 @@ struct MemoryCopyParam ck_tile::index_t warp_id; }; +template +struct type_list +{ +}; + +template +struct type_at; + +template +struct type_at> : type_at> +{ +}; + +template +struct type_at<0, type_list> +{ + using type = Head; +}; + template class TestCkTileMemoryCopy : public ::testing::TestWithParam> { @@ -33,48 +52,47 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam ? 1 : 0; ck_tile::HostTensor x_host({m, n}); ck_tile::HostTensor y_host_dev({m, n}); + ck_tile::HostTensor host_init_buf({x_host.get_element_space_size_in_bytes()}); std::cout << "input: " << x_host.mDesc << std::endl; std::cout << "output: " << y_host_dev.mDesc << std::endl; - ck_tile::index_t value = 1; - for(int i = 0; i < m; i++) - { - value = 1; - for(int j = 0; j < n; j++) - { - value = (value + 1) % 127; - x_host(i, j) = static_cast(value); - } - } - + for(size_t i = 0; i < x_host.get_element_space_size_in_bytes(); i++) + host_init_buf.mData[i] = i % 64; + memcpy(x_host.mData.data(), + host_init_buf.mData.data(), + x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, dword_bytes / sizeof(DataType)>; + using BlockTileList = type_list, ck_tile::sequence<16, 96>>; + using VectorList = type_list, + ck_tile::sequence<1, 24>>; + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = type_at::type; + using WaveTile = type_at::type; + using Vector = type_at::type; ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; + using Problem = ck_tile::TileCopyProblem; using Kernel = ck_tile::TileCopy; constexpr ck_tile::index_t kBlockSize = 128; constexpr ck_tile::index_t kBlockPerCu = 1; + // when copy fp6x16 buffer, tread it as int8 buffer and recompute n-dim size. + ck_tile::index_t cpy_n = + CpyCfg == 1 ? n * sizeof(DataType) / + (sizeof(int8_t) * ck_tile::numeric_traits::PackedSize) + : n; auto ms = launch_kernel( ck_tile::stream_config{nullptr, true}, @@ -85,21 +103,28 @@ class TestCkTileMemoryCopy : public ::testing::TestWithParam(x_buf.GetDeviceBuffer()), static_cast(y_buf.GetDeviceBuffer()), m, - n, + cpy_n, warp_id)); - auto bytes = 2 * m * n * sizeof(DataType); + auto bytes = 2 * m * n * sizeof(DataType) / ck_tile::numeric_traits::PackedSize; std::cout << "elapsed: " << ms << " (ms)" << std::endl; std::cout << (bytes * 1e-6 / ms) << " (GB/s)" << std::endl; // reference y_buf.FromDevice(y_host_dev.mData.data()); bool pass = ck_tile::check_err(y_host_dev, x_host); - EXPECT_TRUE(pass); } }; +class TestCkTileMemoryCopyF6x16Async : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyF6x16 : public TestCkTileMemoryCopy +{ +}; + class TestCkTileMemoryCopyHalfAsync : public TestCkTileMemoryCopy { }; @@ -116,6 +141,18 @@ class TestCkTileMemoryCopyFP8Async : public TestCkTileMemoryCopy { }; +TEST_P(TestCkTileMemoryCopyF6x16, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyF6x16Async, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + TEST_P(TestCkTileMemoryCopyHalfAsync, TestCorrectness) { auto [M, N, warp_id] = GetParam(); @@ -140,6 +177,20 @@ TEST_P(TestCkTileMemoryCopyFP8Async, TestCorrectness) this->Run({M, N, warp_id}); } +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyF6x16Async, + ::testing::Values(std::tuple{32, 128, 0}, + std::tuple{64, 256, 0}, + std::tuple{32, 128, 1}, + std::tuple{64, 256, 1})); + INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, TestCkTileMemoryCopyHalfAsync, ::testing::Values(std::tuple{64, 8, 0}, diff --git a/test/ck_tile/memory_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp index 847763881b..2ce4982a04 100644 --- a/test/ck_tile/memory_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -51,12 +51,15 @@ struct TileCopyShape "Inconsistent wave group size!"); }; -template +template struct TileCopyProblem { using XDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool AsyncCopy = AsyncCopy_; + // 0: copy 1, 2, 4 bytes data type + // 1: copy dwordx3 bytes data type + static constexpr int CpyCfg = CpyCfg_; }; template @@ -67,6 +70,7 @@ struct TileCopy static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; static constexpr bool AsyncCopy = Problem::AsyncCopy; + static constexpr int CpyCfg = Problem::CpyCfg; template CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() @@ -98,8 +102,40 @@ struct TileCopy return make_static_tile_distribution(outer_encoding); } + template + // CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + CK_TILE_DEVICE static constexpr auto MakeDwordx3DRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Block_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // number of active warps working in this thread block. + constexpr index_t Y2 = + warp_size / X0; // number of threads in a warp needed along M dimension. + constexpr index_t Y1 = + S::Warp_M / + Y2; // number of iterations each warp needs to perform to cover the entire tile window. + constexpr auto outer_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, // Y2==16,X0==4 + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}; + + return make_static_tile_distribution(outer_encoding); + } + CK_TILE_DEVICE void - operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + run_normal_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const { using S = typename Problem::BlockShape; @@ -170,6 +206,124 @@ struct TileCopy move_tile_window(y_block_window, {0, S::Block_N}); } } -}; + CK_TILE_DEVICE void + run_dwordx3_cpy(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + constexpr index_t X0 = S::ThreadPerWarp_N; + constexpr index_t X1 = S::Block_N; + constexpr index_t X2 = 12; // l/w dwordx3 bytes + + // LDS buffer + constexpr int dim1_stride = + AsyncCopy ? 16 : 12; // async_load dwordx3 will write 3 bytes & skip 1 bytes in lds. + constexpr int repeat_num = X1 / (X0 * X2); + __shared__ int8_t x_lds[repeat_num * S::Block_M * X0 * dim1_stride]; + + constexpr auto block_dims = make_tuple(number{}, number{}); + constexpr auto block_dims_ = make_tuple(number{}, + number{}, + number{}, + number{}); + constexpr auto block_strides = make_tuple(number{}, + number{}, + number{}, + number<1>{}); + + const auto x_lds_desc_ = + make_naive_tensor_descriptor(block_dims_, block_strides, number<12>{}, number<1>{}); + const auto x_lds_desc = transform_tensor_descriptor( + x_lds_desc_, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number<2>{}, number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = + make_tensor_view(reinterpret_cast(x_lds), x_lds_desc); + + auto x_block_lds_write_window = make_tile_window(x_lds_view, block_dims, {0, 0}); + + auto x_block_lds_read_window = make_tile_window( + x_lds_view, block_dims, {0, 0}, MakeDwordx3DRAMDistribution()); + + const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M); + // Input tensor + const auto x_m_n = + make_naive_tensor_view(reinterpret_cast(p_x), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, block_dims, {iM, 0}, MakeDwordx3DRAMDistribution()); + + // Output tensor + const auto y_m = + make_naive_tensor_view(reinterpret_cast(p_y), + make_tuple(M, N), + make_tuple(N, 1), + number{}, + number<1>{}); + auto y_block_window = make_tile_window(y_m, block_dims, {iM, 0}); + + const index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + const index_t my_id = __builtin_amdgcn_readfirstlane(get_warp_id()); + constexpr index_t async_copy_fence_cnt = 0; + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + if(my_id == warp_id) + { + if constexpr(AsyncCopy) + { + async_load_tile(x_block_lds_write_window, x_block_window); + // We don't have prefetch here, wait the data back immediately. + // Wait all asyncload insts complete. + // Wait all waves synced + s_waitcnt_barrier(); + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + else + { + // load from DRAM to registers + auto dram_tile = load_tile(x_block_window); + // store in lds + store_tile(x_block_lds_write_window, dram_tile); + // Wait all lds write insts complete + // Wait all waves synced + block_sync_lds(); + // read from lds to registers + auto lds_tile = load_tile(x_block_lds_read_window); + // store from registers to DRAM + store_tile(y_block_window, lds_tile); + } + } + + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } + + CK_TILE_DEVICE void + operator()(XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + if constexpr(CpyCfg == 1) + { + run_dwordx3_cpy(p_x, p_y, M, N, warp_id); + } + else if constexpr(CpyCfg == 0) + { + run_normal_cpy(p_x, p_y, M, N, warp_id); + } + else + { + static_assert(false, "unsupported copy config type."); + } + } +}; } // namespace ck_tile diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp index 134d2e5f37..689a7a799a 100644 --- a/test/position_embedding/position_embedding.cpp +++ b/test/position_embedding/position_embedding.cpp @@ -9,6 +9,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef TEST_ALIBI_VERBOSE #define TEST_ALIBI_VERBOSE 0 #endif @@ -213,3 +216,4 @@ int main() // clang-format on return rtn ? 0 : -1; } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp index f8c196e32a..b0d8445c16 100644 --- a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -13,6 +13,9 @@ #include "ck_tile/host.hpp" #include "gemm_multi_d_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-seggestions" + // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -230,3 +233,4 @@ void gemm_multi_d_host_reference(int verify, a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 748fe581d3..41ccc4a01b 100644 --- a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -7,6 +7,9 @@ #include "ck_tile/host.hpp" #include "gemm_preshuffle_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + //[TODO] Move parts of this File to commons enum class Metric { @@ -234,3 +237,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp index 7c8df32ad8..11aef4c251 100644 --- a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp @@ -13,6 +13,8 @@ #include "ck_tile/host.hpp" #include "gemm_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -240,3 +242,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index 45beb0acce..d877f174b2 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -17,6 +17,9 @@ // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + enum class Metric { LATENCY = 0, @@ -199,3 +202,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 877c803d69..c8d6f86ccc 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -481,8 +481,9 @@ struct SelectedKernel {{ AccDataType, TileShape, GemmUniversalTraits>; - - static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ + + static std::tuple launch(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& stream) {{ constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return std::tuple{{time, num_wgs_per_tile}}; }} }}; """ diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index d168030f97..2a7b07c698 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -22,25 +22,25 @@ class GemmProfiler // Overload for single kernel benchmarking void benchmark(GemmProblem& gemm_problem, - std::function kernel_func) + std::function( + const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func) { - // Create a vector with a single callable that returns both name and time - std::vector(ck_tile::StreamKHostArgs&, - const ck_tile::stream_config&)>> + // Create a vector with a single callable that returns name, time, and num_wgs_per_tile + std::vector( + ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>> callables; callables.push_back( [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) { - float time = kernel_func(args, stream); - return std::make_tuple(std::string(KERNEL_NAME), time); + auto [time, num_wgs_per_tile] = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile); }); benchmark(gemm_problem, callables); } void benchmark(GemmProblem& gemm_problem, - std::vector( + std::vector( ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables) { const ALayout layout_a = ALayout{}; @@ -160,9 +160,9 @@ class GemmProfiler ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - const std::tuple& kernel_run_result) + const std::tuple& kernel_run_result) { - auto [name, avg_time] = kernel_run_result; + auto [name, avg_time, num_wgs_per_tile] = kernel_run_result; auto dp_persistent = SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel"; @@ -196,8 +196,7 @@ class GemmProfiler c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool verified_correct = !setting_.verify_ || - compare( - name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result); + compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result); if(verified_correct) {