diff --git a/.gitignore b/.gitignore index d8468cf24e..98234268c1 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,11 @@ __pycache__/ .cache/ +# Generated test data +test_data/* +!test_data/*.py +!test_data/*.sh + # Exceptions to build* patterns above # The experimental/builder directory should be tracked despite matching build* !experimental/builder diff --git a/CMakeLists.txt b/CMakeLists.txt index acae1f5ece..eaed7d3509 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -766,6 +766,9 @@ if(CK_EXPERIMENTAL_BUILDER) ${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile ) + + set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/) + rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile) endif() set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") diff --git a/Jenkinsfile b/Jenkinsfile index 5f03310cab..aea14c78b6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1476,15 +1476,19 @@ pipeline { setup_args = "NO_CK_BUILD" execute_args = """ cd ../build && \ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_dataset_xdl && \ + make -j64 test_grouped_convnd_fwd_dataset_xdl \ + test_grouped_convnd_bwd_data_dataset_xdl \ + test_grouped_convnd_bwd_weight_dataset_xdl && \ cd ../test_data && \ # Dataset generation modes: # - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes) # - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time # - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time - ./generate_test_dataset.sh half && \ + ./generate_test_dataset.sh small && \ cd ../build && \ - ./bin/test_grouped_convnd_fwd_dataset_xdl""" + ./bin/test_grouped_convnd_fwd_dataset_xdl && \ + ./bin/test_grouped_convnd_bwd_data_dataset_xdl && \ + ./bin/test_grouped_convnd_bwd_weight_dataset_xdl""" } 
steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc index 78135d6296..69eb42defd 100644 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } return 0; diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc index 67f29dbc36..a7451fab71 100644 --- a/example/26_contraction/run_contraction_scale_example.inc +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 
0 : 1; + } } return 0; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 47c47334e7..f79494a478 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,40 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 
32 : 64; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; @@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static 
constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; @@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 0fcf9680bc..4a83a2c4ab 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ 
b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -63,25 +63,30 @@ struct UniversalInvoker const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index c5a400b4dd..67b411c1f0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,40 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 
32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool kPadK = true; @@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static 
constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 30a25d83d7..2724834bb5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,24 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 0317685770..1fa8a03087 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -10,40 +10,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 
16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 7a4760e1da..37fc998e5b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector& inputs) return combined_hash; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 
32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template static constexpr inline auto is_row_major(Layout layout_) { @@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleQuant = true; }; @@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; 
static constexpr bool DoubleSmemBuffer = true; @@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 9a9c2235e0..99e7479e36 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -84,63 +84,46 @@ namespace ck_tile::builder::factory { // CK Tile kernel template -consteval bool IsTileAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && - SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && - SpecifiesTileOptimizations; -} +concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template -consteval bool IsXdlV3Algorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesBlockGemm; -} +concept IsXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; // Standard 
XDL-based kernel (uses XDLops hardware instructions for matrix multiply) template -consteval bool IsXdlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && - SpecifiesLoopScheduler; -} +concept IsXdlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; // WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) template -consteval bool IsWmmaAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; -} +concept IsWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; // Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts template -consteval bool IsDlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && 
SpecifiesDlEpilogue; -} +concept IsDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; // XDL-based kernel with large tensor support template -consteval bool IsLargeTensorAlgorithm() -{ - return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; -} +concept IsLargeTensorAlgorithm = + IsXdlAlgorithm && SpecifiesLargeTensorSupport; template ; // CK Tile supports common factory for each direction - if constexpr(IsTileAlgorithm()) + if constexpr(IsTileAlgorithm) { return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm()) + if constexpr(IsXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm()) + else if constexpr(IsXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm()) + else if constexpr(IsWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm()) + else if constexpr(IsDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm()) + else if constexpr(IsLargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } diff --git a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp index 14b8e75668..8db0e5d25d 100644 --- a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp @@ -31,13 +31,20 @@ constexpr size_t data_type_sizeof(DataType data_type) { case DataType::UNDEFINED_DATA_TYPE: return 0; case DataType::FP32: return 4; + case DataType::FP32_FP32: return 8; case DataType::FP16: return 2; + case DataType::FP16_FP16: return 4; case 
DataType::BF16: return 2; + case DataType::BF16_BF16: return 4; case DataType::FP8: return 1; + case DataType::BF8: return 1; + case DataType::FP64: return 8; case DataType::INT32: return 4; case DataType::I8: return 1; + case DataType::I8_I8: return 2; case DataType::U8: return 1; } + return 0; // Default case to ensure all control paths return a value } } // namespace ck_tile::builder::test diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 8739f65740..43e9350f8f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -72,7 +72,12 @@ inline bool is_xdl_supported() is_gfx12_supported() || is_gfx11_supported(); } -template +template inline bool is_xdl_wmma_supported() { if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || @@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported() } else if(is_gfx12_supported() || is_gfx11_supported()) { - if constexpr((MPerXDL != 16) || (NPerXDL != 16)) + if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 3e37aac86e..9179a279c5 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -17,6 +17,7 @@ #endif #endif #include "ck/utility/get_id.hpp" +#include "ck/utility/sequence.hpp" namespace ck { namespace tensor_operation { @@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2() IsWave64>(); \ } +template +static constexpr auto GetWarpTileConfig() +{ + constexpr auto MXdlPerWave64 = MXdlPerWave_; + constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16; + constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16; + + constexpr auto NXdlPerWave = + IsWave64 + ? 
GetNXdlPerWave2() + : GetNXdlPerWave2(); + + if constexpr(IsWave64 == false && NXdlPerWave != 0) + { + constexpr auto CShuffleNXdlPerWavePerShuffle32 = + NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + : CShuffleNXdlPerWavePerShuffle_; + static_assert(CShuffleNXdlPerWavePerShuffle32 > 0); + return Sequence<16, + 16, + MXdlPerWave32, + NXdlPerWave, + CShuffleMXdlPerWavePerShuffle32, + CShuffleNXdlPerWavePerShuffle32>{}; + } + else + { + return Sequence{}; + } +} + #define INVOKER_RUN_IMPL \ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ { \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 83dbebb8d6..fff435f1c2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); // GridwiseGemm - template + template using GridwiseGemmBase 
= GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " + << MPerXDL << ", " + << NPerXDL << ", " << AK1 << ", " << BK1 << ", " << ABlockTransferSrcVectorDim << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index ec48beb789..1db9fd45b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public 
DeviceGroupedGemmSplitK 1) + // TODO: Enable splitK + if(a.k_batch > 1) + { + if constexpr(std::is_same_v) + { + if(a.StrideC != a.N) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " SplitK (k_batch=" << a.k_batch + << ") requires contiguous output stride." + << " For RowMajor layout: StrideC must equal N." + << " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl; + } + return false; + } + } + else if constexpr(std::is_same_v) + { + if(a.StrideC != a.M) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " SplitK (k_batch=" << a.k_batch + << ") requires contiguous output stride." + << " For ColumnMajor layout: StrideC must equal M." + << " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl; + } + return false; + } + } + } + bool group_arg_valid = false; if(isWave64) { diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 0e7d1def75..08d555d27c 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state #else static constexpr bool CK_TILE_ARCH_GFX1010 = false; #endif +#if defined(__gfx1011__) + static constexpr bool CK_TILE_ARCH_GFX1011 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1011 = false; +#endif +#if defined(__gfx1012__) + static constexpr bool CK_TILE_ARCH_GFX1012 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1012 = false; +#endif +#if defined(__gfx1013__) + static constexpr bool CK_TILE_ARCH_GFX1013 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1013 = false; +#endif +#if defined(__gfx10_1_generic__) + static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true; +#else + static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false; +#endif // __gfx10_1_generic__ #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; @@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr 
uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index a1edce804f..5c99ae8a1c 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) } template -auto shuffle_b(const ck_tile::HostTensor& t) +auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; @@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor& t) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 
2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } } +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + return shuffle_b(t, GemmConfig{}); +} + template auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { @@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) } template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 
2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } } + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +{ + return shuffle_b_permuteN(t, GemmConfig{}); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index ad1862306a..53bfa6041d 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -35,7 +35,8 @@ template // The number of continuous xdl_output per warp + index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp + bool DoubleSmemBuffer_ = false> struct CShuffleEpilogueProblem { using AsDataType = remove_cvref_t; @@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t VectorSizeC = VectorSizeC_; static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); @@ -118,6 +120,7 @@ struct CShuffleEpilogue static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; static constexpr index_t VectorSizeC = 
Problem::VectorSizeC; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; @@ -204,6 +207,26 @@ struct CShuffleEpilogue } return max_vector_size / sizeof(DiDataType); } + + /** + * @brief Shuffle tile configuration parameters check and alignment + * + * @details Returns tuple(1, 1) if the requested shuffle tile exceeds the SMEM capacity or DoubleSmemBuffer is enabled; otherwise returns the requested (m, n) shuffle tile. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem() + { + constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile; + constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile; + + constexpr auto shuffle_tile = + m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer + ? std::make_tuple(1, 1) + : std::make_tuple(m_shuffle_tile, n_shuffle_tile); + + return shuffle_tile; + } + /** * @brief Shuffle tile configuration parameters * @@ -214,20 +237,23 @@ struct CShuffleEpilogue */ static constexpr auto shuffle_tile_tuple = [] { constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size(); - if constexpr(elem_per_thread >= GetVectorSizeC()) + if constexpr(elem_per_thread <= GetVectorSizeC()) { return std::make_tuple(1, 1); } else { - constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; + constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC(); + static_assert(elem_per_thread % GetVectorSizeC() == 0); if constexpr(std::is_same_v) { static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && (kMPerBlock % num_xdl_shuffles == 0), "kMPerBlock must be divisible by MPerXdl*MWave and " "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1); + return AlignShuffleTileWithSmem(); } else { @@ -235,7 +261,9 @@ struct CShuffleEpilogue (kNPerBlock % num_xdl_shuffles == 0), "kNPerBlock must be divisible by NPerXdl*NWave and " "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / 
(NPerXdl * NWave))); + return AlignShuffleTileWithSmem<1, + min(num_xdl_shuffles, + kNPerBlock / (NPerXdl * NWave))>(); } } }(); diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 8adbfb9723..2761b16571 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -232,7 +232,7 @@ struct BatchedGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr1[GetSmemSize()]; + __shared__ char smem_ptr1[GemmPipeline::GetSmemSize()]; UniversalGemmKernel::RunGemm2LDS({a_ptr}, {b_ptr}, {/*ds_ptr*/}, diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 838fc236d2..95114e8496 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -310,7 +310,7 @@ struct GroupedGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, c_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 866a4cc693..5f7e78fac2 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1084,7 +1084,7 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) @@ -1169,7 +1169,7 @@ struct UniversalGemmKernel // Run the GEMM if constexpr(GemmPipeline::DoubleSmemBuffer == 
true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 8029f6a2c7..aa8469be4f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -43,4 +43,26 @@ struct TileGemmShape } }; +template +constexpr index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32; + else + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 
32 : 64; +#endif +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 628e9194ae..cb452043d1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BLayout = remove_cvref_t; using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; - + using OverrideBDataType = std::conditional_t< + std::is_same_v && + std::is_same_v, + ADataType, + BDataType>; using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index d9af5cce1f..3e97380374 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1324,7 +1324,7 @@ struct QuantGemmKernel assert(kargs.k_batch == 1); if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 726f678d37..7e246961cb 100644 --- 
a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel kQuantType == QuantType::BQuantGrouped) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, aq_ptr, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 2c191cc2b4..f6ebbd9228 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + std::conditional_t && + std::is_same_v, + ADataType, + BDataType>; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; @@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; - using ALayout = remove_cvref_t; - using BQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); } + template + CK_TILE_DEVICE void + BGlobalPrefetch(BBlockTile_& b_block_tile, + BDramWindow& b_copy_dram_window, + const BDramTileWindowStep& b_dram_tile_window_step) const + { + if 
constexpr(!std::is_same_v) + { + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + } + else + { + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + } + } + template (ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3::value)) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index c9e81d4744..2e80ff64c1 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && GroupedConvTraitsType_::VectorSizeC % 2 != 0 && diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index a9f3274805..0f143d7ff7 100644 --- 
a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && GroupedConvTraitsType_::VectorSizeC % 2 != 0 && diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index a0c078a1e9..e949ed45e6 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -202,7 +202,13 @@ class TestCkTileGemmPipeline : public ::testing::Test N_Warp_Tile, K_Warp_Tile, UniversalGemmProblem::TransposeC, - memory_operation>>; + memory_operation, + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + false, /*TiledMMAPermuteN_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer /*DoubleSmemBuffer*/>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 2b0ffaafa2..1542275916 100755 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -11,25 +11,93 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time - # AQuant tests - add_gtest_executable(test_tile_gemm_quant_aquant - test_gemm_quant_aquant.cpp + # AQuant tests - split into 6 files + add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr + test_gemm_quant_aquant_base_rcr.cpp ) - target_compile_options(test_tile_gemm_quant_aquant PRIVATE 
${TEST_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_tile_gemm_quant_aquant_base_rcr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - # BQuant tests (without PreshuffleB) - add_gtest_executable(test_tile_gemm_quant_bquant - test_gemm_quant_bquant.cpp + add_gtest_executable(test_tile_gemm_quant_aquant_base_rrr_crr + test_gemm_quant_aquant_base_rrr_crr.cpp ) - target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_tile_gemm_quant_aquant_base_rrr_crr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - # BQuant tests (with PreshuffleB) - # disabling this test until it can be built within reasonable time! - # currently taking ~50 minutes on gfx12! - #add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle - # test_gemm_quant_bquant_preshuffle.cpp - #) - #target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr + test_gemm_quant_aquant_base_ccr.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_prefill + test_gemm_quant_aquant_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c + test_gemm_quant_aquant_transpose_c.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle + test_gemm_quant_aquant_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # BQuant tests (without PreshuffleB) - split into 6 files + add_gtest_executable(test_tile_gemm_quant_bquant_1d_128 + test_gemm_quant_bquant_1d_128.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_1d_128 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + 
+ add_gtest_executable(test_tile_gemm_quant_bquant_1d_64 + test_gemm_quant_bquant_1d_64.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_1d_64 PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_2d_small_n + test_gemm_quant_bquant_2d_small_n.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_2d_small_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_2d_medium_n + test_gemm_quant_bquant_2d_medium_n.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_2d_medium_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_2d_large_n + test_gemm_quant_bquant_2d_large_n.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_2d_large_n PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_transpose + test_gemm_quant_bquant_transpose.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # BQuant tests (with PreshuffleB) - split into 5 files + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d + test_gemm_quant_bquant_preshuffle_decode_1d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_1d + test_gemm_quant_bquant_preshuffle_prefill_1d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_1d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_tiled_permute + test_gemm_quant_bquant_preshuffle_tiled_permute.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffle_tiled_permute PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_2d + test_gemm_quant_bquant_preshuffle_decode_2d.cpp + ) + 
target_compile_options(test_tile_gemm_quant_bquant_preshuffle_decode_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_prefill_2d + test_gemm_quant_bquant_preshuffle_prefill_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffle_prefill_2d PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # RowColQuant tests add_gtest_executable(test_tile_gemm_quant_rowcol @@ -42,6 +110,34 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_gemm_quant_tensor.cpp ) target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # Umbrella target to build all gemm quant tests + add_custom_target(test_tile_gemm_quant_all) + add_dependencies(test_tile_gemm_quant_all + # AQuant tests + test_tile_gemm_quant_aquant_base_rcr + test_tile_gemm_quant_aquant_base_rrr_crr + test_tile_gemm_quant_aquant_base_ccr + test_tile_gemm_quant_aquant_prefill + test_tile_gemm_quant_aquant_transpose_c + test_tile_gemm_quant_aquant_preshuffle + # BQuant tests + test_tile_gemm_quant_bquant_1d_128 + test_tile_gemm_quant_bquant_1d_64 + test_tile_gemm_quant_bquant_2d_small_n + test_tile_gemm_quant_bquant_2d_medium_n + test_tile_gemm_quant_bquant_2d_large_n + test_tile_gemm_quant_bquant_transpose + # BQuant preshuffle tests + test_tile_gemm_quant_bquant_preshuffle_decode_1d + test_tile_gemm_quant_bquant_preshuffle_prefill_1d + test_tile_gemm_quant_bquant_preshuffle_tiled_permute + test_tile_gemm_quant_bquant_preshuffle_decode_2d + test_tile_gemm_quant_bquant_preshuffle_prefill_2d + # Other quant tests + test_tile_gemm_quant_rowcol + test_tile_gemm_quant_tensor + ) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp deleted file mode 100644 index b6e69cd649..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp +++ /dev/null @@ -1,95 +0,0 @@ -// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using AQuantGrouped = std::integral_constant; -using BQuantGrouped = std::integral_constant; -using RowColQuant = std::integral_constant; -using TensorQuant = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for AQuant tests -// Tuple format: -// clang-format off -using AQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // RCR layout - with the Prefill BlockTile Config. 
- std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for AQuant -TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp new file mode 100644 index 0000000000..0e04f9fc9e --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - CCR layout +// Tuple format: +// clang-format off +using AQuantBaseCCRTypes = ::testing::Types< + // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base CCR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseCCRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp new file mode 100644 index 0000000000..da32c06304 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - RCR layout base configuration +// Tuple format: +// clang-format off +using AQuantBaseRCRTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base RCR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRCRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp new file mode 100644 index 0000000000..6e90c44764 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - RRR and CRR layouts +// Tuple format: +// clang-format off +using AQuantBaseRRRCRRTypes = ::testing::Types< + // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base RRR/CRR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRRRCRRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp new file mode 100644 index 0000000000..133c11860a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Prefill Configuration +// Tuple format: +// clang-format off +using AQuantPrefillTypes = ::testing::Types< + // RCR layout - with the Prefill BlockTile Config. + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Prefill +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPrefillTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp new file mode 100644 index 0000000000..35d15f9354 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - PreshuffleQuant Configurations +// Tuple format: +// clang-format off +using AQuantPreshuffleTypes = ::testing::Types< + // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Preshuffle +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPreshuffleTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp new file mode 100644 index 0000000000..a2a4c2c38b --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - TransposeC Configuration +// Tuple format: +// clang-format off +using AQuantTransposeCTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant TransposeC +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTransposeCTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp deleted file mode 100644 index 4b1ad068a7..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using BF16 = ck_tile::bf16_t; -using UInt8 = ck_tile::pk_fp4_raw_t; -using BQuantGrouped = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; -using GroupSize32 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: -// clang-format off -using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - - // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // some cases with transpose layouts - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple, - - // 
pkint4 + transpose cases - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for BQuant (without PreshuffleB) -TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); - -// BQuant tests -TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp new file mode 100644 index 0000000000..d491d89ef4 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git 
a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp new file mode 100644 index 0000000000..1019caf1bc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp new file mode 100644 index 0000000000..a8b6dcd14b --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Large N (128N) +// Tuple format: +// clang-format off +using BQuant2DLargeNTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Large N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DLargeNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp new file mode 100644 index 0000000000..67d52ef874 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Medium N (32N and 64N) +// Tuple format: +// clang-format off +using BQuant2DMediumNTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Medium N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DMediumNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp new file mode 100644 index 0000000000..865713992d --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Small N (8N and 16N) +// Tuple format: +// clang-format off +using BQuant2DSmallNTypes = ::testing::Types< + // 2d cases with grouping also on the n axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Small N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DSmallNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp deleted file mode 100644 index ae01bddf96..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using BQuantGrouped = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; - -// Type combinations for BQuant tests with PreshuffleB -// Tuple format: -// clang-format off -using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for BQuant with PreshuffleB -TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); - -// BQuant PreshuffleB tests -TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp 
b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp new file mode 100644 index 0000000000..cf599ebbfd --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode Config 1D +// Tuple format: +// clang-format off +using BPreshuffleDecode1DTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp new file mode 100644 index 0000000000..65ea165b10 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode 2D +// Tuple format: +// clang-format off +using BPreshuffleDecode2DTypes = ::testing::Types< + // 2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp new file mode 100644 index 0000000000..3f6dd225d7 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill Config 1D +// Tuple format: +// clang-format off +using BPreshufflePrefill1DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp new file mode 100644 index 0000000000..368204987a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill 2D +// Tuple format: +// clang-format off +using BPreshufflePrefill2DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp new file mode 100644 index 0000000000..8a05f5812a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - TiledPermuteN Config +// Tuple format: +// clang-format off +using BPreshuffleTiledPermuteTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle TiledPermuteN +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleTiledPermuteTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp new file mode 100644 index 0000000000..230dd8f0fc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - Transpose Layouts +// Tuple format: +// clang-format off +using BQuantTransposeTypes = ::testing::Types< + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Transpose +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTransposeTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 5628b6feae..a7189e7865 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ 
b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -11,26 +11,6 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -#endif -} - template class TestCkTileGroupedGemmPreshuffle : public ::testing::Test { @@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 16; static const ck_tile::index_t N_Warp_Tile = 16; - static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static const ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported @@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } - template - auto shuffle_b(const ck_tile::HostTensor& t) - { - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if(ck_tile::is_gfx12_supported()) - { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / N_Warp_Tile, - N_Warp_Tile, - k_ / K_Warp_Tile, - kABK0PerLane, - divisor, - kABK1PerLane}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); - } - else - { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) 
- { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = N_Warp_Tile == 32 ? 2 : 4; - } - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, @@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test } } + struct BShuffleGemmConfig + { + static constexpr ck_tile::index_t N_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::N_Warp_Tile; + static constexpr ck_tile::index_t K_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::K_Warp_Tile; + }; + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); // Host-side preshuffle of B - auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + auto b_shuffle_host = ck_tile::shuffle_b(b_k_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 892e123d3d..55f09726cc 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -6,18 +6,18 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - # Split into three separate test executables for faster parallel compilation - add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") +# # Split into three separate test executables 
for faster parallel compilation +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) # add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) # target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -endif() +# endif() diff --git a/test/common/csv_test_loader.hpp b/test/common/csv_test_loader.hpp new file mode 100644 index 0000000000..78d3595f1a --- /dev/null +++ b/test/common/csv_test_loader.hpp @@ -0,0 +1,246 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/library/utility/convolution_parameter.hpp" + +namespace ck { +namespace test { + +namespace fs = std::filesystem; + +// Helper function to find test_data directory relative to the test binary +static std::string GetTestDataPath() +{ + // Get the path to the current executable + fs::path exe_path = fs::read_symlink("/proc/self/exe"); + + // Get the directory containing the executable + fs::path current_dir = exe_path.parent_path(); + + // Search for test_data directory by going up the directory tree + // This makes the code robust regardless of build directory depth + while(current_dir != current_dir.root_path()) + { + fs::path test_data_path = current_dir / "test_data"; + if(fs::exists(test_data_path) && fs::is_directory(test_data_path)) + { + return test_data_path.string(); + } + current_dir = current_dir.parent_path(); + } + + // If not found, return empty string + std::cerr << "ERROR: Could not find test_data directory relative to executable" << std::endl; + return ""; +} + +// CSV Reader Function for Loading Test Cases +// Reads convolution parameters from CSV file and returns vector of ConvParam structures +inline std::vector load_csv_test_cases(const std::string& filename) +{ + std::vector conv_params; // Return vector + std::ifstream file(filename); // Open CSV file + + if(!file.is_open()) + { + std::cerr << "ERROR: Cannot open CSV file: " << filename << std::endl; + return conv_params; // Return empty vector on error + } + + std::string line; + int line_number = 0; + + // Read file line by line + while(std::getline(file, line)) + { + line_number++; + // Skip comment lines (starting with #) and empty lines + if(line.empty() || line[0] == '#') + { + continue; + } + + // Skip header line (contains column names) + if(line.find("NDim,Groups,BatchSize") != std::string::npos) + { + continue; + } + + // Parse CSV line using stringstream + 
std::stringstream ss(line); + std::string cell; + std::vector row; + + // Split line by commas + while(std::getline(ss, cell, ',')) + { + row.push_back(cell); + } + + // Validate row has correct number of columns + if(row.size() < 19) + { // Need at least 19 columns for 2D (excluding TestName) + std::cerr << "WARNING: Line " << line_number << " has insufficient columns (" + << row.size() << "), skipping" << std::endl; + continue; + } + + try + { + // Parse CSV data into ConvParam structure + // CSV Format: + // NDim,Groups,BatchSize,OutChannels,InChannels,KernelH,KernelW,InputH,InputW,OutputH,OutputW,StrideH,StrideW,DilationH,DilationW,LeftPadH,LeftPadW,RightPadH,RightPadW,TestName + int NDim = std::stoi(row[0]); + int Groups = std::stoi(row[1]); + int BatchSize = std::stoi(row[2]); + int OutChannels = std::stoi(row[3]); + int InChannels = std::stoi(row[4]); + + if(NDim == 1) + { + // 1D Convolution: Need fewer columns for 1D parameters + if(row.size() < 13) + { + std::cerr << "WARNING: 1D convolution on line " << line_number + << " needs 13+ columns, has " << row.size() << ", skipping" + << std::endl; + continue; + } + // 1D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels, + // {KernelW}, {InputW}, {StrideW}, {DilationW}, {LeftPadW}, {RightPadW}} + ck::utils::conv::ConvParam param = { + NDim, // NDim = 1 + Groups, // Groups + BatchSize, // Batch size + OutChannels, // Output channels + InChannels, // Input channels + {std::stoi(row[5])}, // Kernel: {W} + {std::stoi(row[7])}, // Input: {W} + {std::stoi(row[11])}, // Stride: {W} + {std::stoi(row[13])}, // Dilation: {W} + {std::stoi(row[15])}, // Left pad: {W} + {std::stoi(row[17])} // Right pad: {W} + }; + conv_params.push_back(param); + } + else if(NDim == 2) + { + // 2D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels, + // {KernelH,KernelW}, {InputH,InputW}, {StrideH,StrideW}, {DilationH,DilationW}, + // {LeftPadH,LeftPadW}, {RightPadH,RightPadW}} + ck::utils::conv::ConvParam param = 
{ + NDim, // NDim = 2 + Groups, // Groups + BatchSize, // Batch size + OutChannels, // Output channels + InChannels, // Input channels + {std::stoi(row[5]), std::stoi(row[6])}, // Kernel: {H, W} + {std::stoi(row[7]), std::stoi(row[8])}, // Input: {H, W} + {std::stoi(row[11]), std::stoi(row[12])}, // Stride: {H, W} + {std::stoi(row[13]), std::stoi(row[14])}, // Dilation: {H, W} + {std::stoi(row[15]), std::stoi(row[16])}, // Left pad: {H, W} + {std::stoi(row[17]), std::stoi(row[18])} // Right pad: {H, W} + }; + conv_params.push_back(param); + } + else if(NDim == 3) + { + // 3D Convolution: Need more columns for 3D parameters + if(row.size() < 26) + { + std::cerr << "WARNING: 3D convolution on line " << line_number + << " needs 26+ columns, has " << row.size() << ", skipping" + << std::endl; + continue; + } + // 3D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels, + // {KernelD,KernelH,KernelW}, {InputD,InputH,InputW}, {OutputD,OutputH,OutputW}, + // {StrideD,StrideH,StrideW}, {DilationD,DilationH,DilationW}, + // {LeftPadD,LeftPadH,LeftPadW}, {RightPadD,RightPadH,RightPadW}} + ck::utils::conv::ConvParam param = { + NDim, // NDim = 3 + Groups, // Groups + BatchSize, // Batch size + OutChannels, // Output channels + InChannels, // Input channels + {std::stoi(row[5]), std::stoi(row[6]), std::stoi(row[7])}, // Kernel: {D, H, W} + {std::stoi(row[8]), std::stoi(row[9]), std::stoi(row[10])}, // Input: {D, H, W} + {std::stoi(row[14]), + std::stoi(row[15]), + std::stoi(row[16])}, // Stride: {D, H, W} + {std::stoi(row[17]), + std::stoi(row[18]), + std::stoi(row[19])}, // Dilation: {D, H, W} + {std::stoi(row[20]), + std::stoi(row[21]), + std::stoi(row[22])}, // Left pad: {D, H, W} + {std::stoi(row[23]), + std::stoi(row[24]), + std::stoi(row[25])} // Right pad: {D, H, W} + }; + conv_params.push_back(param); + } + else + { + std::cerr << "WARNING: Unsupported NDim=" << NDim << " on line " << line_number + << ", skipping" << std::endl; + } + } + catch(const 
std::exception& e) + { + std::cerr << "ERROR: Failed to parse line " << line_number << ": " << e.what() + << std::endl; + continue; + } + } + + file.close(); + std::cout << "Loaded " << conv_params.size() << " test cases from " << filename << std::endl; + return conv_params; +} + +// Helper function to load CSV test cases and populate conv_params vector +// Returns true if loading succeeded, false otherwise +inline bool load_and_populate_test_cases(const std::vector& csv_paths, + std::vector& conv_params, + const std::string& dimension_label) +{ + for(const auto& csv_path : csv_paths) + { + auto csv_cases = load_csv_test_cases(csv_path); + if(!csv_cases.empty()) + { + // Successfully loaded CSV data - add all test cases to conv_params + for(const auto& test_case : csv_cases) + { + conv_params.push_back(test_case); + } + std::cout << "Loaded " << csv_cases.size() << " " << dimension_label + << " test cases from " << csv_path << std::endl; + return true; + } + } + + // Failed to load from any path + std::cerr << "ERROR: Failed to load CSV test data from any of these locations:" << std::endl; + for(const auto& path : csv_paths) + { + std::cerr << " - " << path << std::endl; + } + std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl; + std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets." 
<< std::endl; + + return false; +} + +} // namespace test +} // namespace ck diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 1da477ebb3..a9413bd25b 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -9,6 +9,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp) target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + + add_executable(test_grouped_convnd_bwd_data_dataset_xdl test_grouped_convnd_bwd_data_dataset_xdl.cpp) + target_compile_options(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_data_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp) if(result EQUAL 0) diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_dataset_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_dataset_xdl.cpp new file mode 100644 index 0000000000..53b8ec32af --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_dataset_xdl.cpp @@ -0,0 +1,317 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include // Standard C library (exit codes, malloc) +#include // C++ I/O streams (cout, cerr) +#include // C++ initializer list support (unused here) +#include // C++ vector container - stores test cases +#include // String operations +#include // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" // The actual GPU profiler that does convolution work +#include "../common/csv_test_loader.hpp" // Shared CSV test case loader + +using namespace ck::tensor_layout::convolution; // Import tensor layout names (GNHWK, GKYXC, etc.) + +// Load CSV data for 2D tests +static std::vector Get2DTestCases() +{ + static std::vector test_cases; + if(test_cases.empty()) + { + std::string test_data_dir = ck::test::GetTestDataPath(); + if(test_data_dir.empty()) + { + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl; + } + } + return test_cases; +} + +// Load CSV data for 3D tests +static std::vector Get3DTestCases() +{ + static std::vector test_cases; + if(test_cases.empty()) + { + std::string test_data_dir = ck::test::GetTestDataPath(); + if(test_data_dir.empty()) + { + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl; + } + } + return test_cases; +} + +// Helper template to run a single backward data convolution test with split_k +template +bool 
RunConvBwdDataTest(const ck::utils::conv::ConvParam& param, ck::index_t split_k) +{ + return ck::profiler::profile_grouped_conv_bwd_data_impl(true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param, // ConvParam + split_k, // Split-K value + -1); // instance_index +} + +// 2D Tests - GNHWK layout - Float - SplitK=1 +class TestGroupedConvndBwdData2dGNHWKFloatSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKFloatSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, float>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKFloatSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - GNHWK layout - Float - SplitK=2 +class TestGroupedConvndBwdData2dGNHWKFloatSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKFloatSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, float>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKFloatSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - GNHWK layout - Half - SplitK=1 +class TestGroupedConvndBwdData2dGNHWKHalfSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKHalfSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::half_t>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKHalfSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - GNHWK layout - Half - SplitK=2 +class TestGroupedConvndBwdData2dGNHWKHalfSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKHalfSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::half_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKHalfSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D 
Tests - GNHWK layout - BFloat16 - SplitK=1 +class TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::bhalf_t>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKBFloat16SplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - GNHWK layout - BFloat16 - SplitK=2 +class TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, GNHWK, GKYXC, GNHWC, ck::bhalf_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dGNHWKBFloat16SplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Float - SplitK=1 +class TestGroupedConvndBwdData2dNHWGKFloatSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKFloatSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, float>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKFloatSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Float - SplitK=2 +class TestGroupedConvndBwdData2dNHWGKFloatSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKFloatSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, float>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKFloatSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Half - SplitK=1 +class TestGroupedConvndBwdData2dNHWGKHalfSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKHalfSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::half_t>(GetParam(), 1))); +} 
+INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKHalfSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Half - SplitK=2 +class TestGroupedConvndBwdData2dNHWGKHalfSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKHalfSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::half_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKHalfSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - BFloat16 - SplitK=1 +class TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::bhalf_t>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKBFloat16SplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - BFloat16 - SplitK=2 +class TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<2, NHWGK, GKYXC, NHWGC, ck::bhalf_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData2dNHWGKBFloat16SplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 3D Tests - NDHWGK layout - Float - SplitK=1 +class TestGroupedConvndBwdData3dNDHWGKFloatSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData3dNDHWGKFloatSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, float>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKFloatSplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Float - SplitK=2 +class TestGroupedConvndBwdData3dNDHWGKFloatSplitK2 + : public ::testing::TestWithParam +{ +}; 
+TEST_P(TestGroupedConvndBwdData3dNDHWGKFloatSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, float>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKFloatSplitK2, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Half - SplitK=1 +class TestGroupedConvndBwdData3dNDHWGKHalfSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData3dNDHWGKHalfSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::half_t>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKHalfSplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Half - SplitK=2 +class TestGroupedConvndBwdData3dNDHWGKHalfSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData3dNDHWGKHalfSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::half_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKHalfSplitK2, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - BFloat16 - SplitK=1 +class TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::bhalf_t>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - BFloat16 - SplitK=2 +class TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdDataTest<3, NDHWGK, GKZYXC, NDHWGC, ck::bhalf_t>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdData3dNDHWGKBFloat16SplitK2, + 
::testing::ValuesIn(Get3DTestCases())); diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index e46113bea0..7b994f5bb8 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -4,6 +4,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance) + + add_executable(test_grouped_convnd_bwd_weight_dataset_xdl test_grouped_convnd_bwd_weight_dataset_xdl.cpp) + target_compile_options(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_weight_dataset_xdl PRIVATE gtest_main getopt::getopt utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance) elseif(DL_KERNELS) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp new file mode 100644 index 0000000000..aff6ba8873 --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_dataset_xdl.cpp @@ -0,0 +1,258 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include // Standard C library (exit codes, malloc) +#include // C++ I/O streams (cout, cerr) +#include // C++ initializer list support (unused here) +#include // C++ vector container - stores test cases +#include // String operations +#include // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/host_utility/device_prop.hpp" + +#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" // The actual GPU profiler that does convolution work +#include "../common/csv_test_loader.hpp" // Shared CSV test case loader + +using namespace ck::tensor_layout::convolution; + +// Load CSV data for 2D tests +static std::vector Get2DTestCases() +{ + static std::vector test_cases; + if(test_cases.empty()) + { + std::string test_data_dir = ck::test::GetTestDataPath(); + if(test_data_dir.empty()) + { + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl; + } + } + return test_cases; +} + +// Load CSV data for 3D tests +static std::vector Get3DTestCases() +{ + static std::vector test_cases; + if(test_cases.empty()) + { + std::string test_data_dir = ck::test::GetTestDataPath(); + if(test_data_dir.empty()) + { + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl; + } + } + return test_cases; +} + +// Helper 
template to run a single backward weight convolution test +template +bool RunConvBwdWeightTest(const ck::utils::conv::ConvParam& param, ck::index_t split_k) +{ + return ck::profiler::profile_grouped_conv_bwd_weight_impl( + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param, // ConvParam + std::to_string(split_k), // Split-K value as string + -1); // instance_index +} + +// 2D Tests - NHWGK layout - Float - SplitK=1 +class TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, float, float, float>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKFloatSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Float - SplitK=2 +class TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, float, float, float>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKFloatSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Half - SplitK=1 +class TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::half_t, ck::half_t, ck::half_t>( + GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKHalfSplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - Half - SplitK=2 +class TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2, ConvTest) +{ + 
EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::half_t, ck::half_t, ck::half_t>( + GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKHalfSplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - BFloat16 - SplitK=1 +class TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t, float, ck::bhalf_t>( + GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK1, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - NHWGK layout - BFloat16 - SplitK=2 +class TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t, float, ck::bhalf_t>( + GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight2dNHWGKBFloat16SplitK2, + ::testing::ValuesIn(Get2DTestCases())); + +// 3D Tests - NDHWGK layout - Float - SplitK=1 +class TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1, ConvTest) +{ + EXPECT_TRUE( + (RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, float, float, float>(GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Float - SplitK=2 +class TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2, ConvTest) +{ + EXPECT_TRUE( + (RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, float, float, float>(GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + 
TestGroupedConvndBwdWeight3dNDHWGKFloatSplitK2, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Half - SplitK=1 +class TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1, ConvTest) +{ + EXPECT_TRUE( + (RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t, ck::half_t, ck::half_t>( + GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - Half - SplitK=2 +class TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2, ConvTest) +{ + EXPECT_TRUE( + (RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t, ck::half_t, ck::half_t>( + GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight3dNDHWGKHalfSplitK2, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - BFloat16 - SplitK=1 +class TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t, float, ck::bhalf_t>( + GetParam(), 1))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK1, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - NDHWGK layout - BFloat16 - SplitK=2 +class TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2 + : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2, ConvTest) +{ + EXPECT_TRUE((RunConvBwdWeightTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t, float, ck::bhalf_t>( + GetParam(), 2))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndBwdWeight3dNDHWGKBFloat16SplitK2, + ::testing::ValuesIn(Get3DTestCases())); diff --git 
a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dataset_xdl.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dataset_xdl.cpp index 0928256817..c99f7ccf2f 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dataset_xdl.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_dataset_xdl.cpp @@ -5,330 +5,165 @@ #include // C++ I/O streams (cout, cerr) #include // C++ initializer list support (unused here) #include // C++ vector container - stores test cases -#include // File I/O for CSV reading -#include // String stream for CSV parsing #include // String operations -#include // Google Test framework - provides TYPED_TEST, EXPECT_TRUE +#include // Google Test framework - provides TEST_P, INSTANTIATE_TEST_SUITE_P #include "profiler/profile_grouped_conv_fwd_impl.hpp" // The actual GPU profiler that does convolution work - -// CSV Reader Function for Loading Test Cases -// Reads convolution parameters from CSV file and returns vector of ConvParam structures -std::vector load_csv_test_cases(const std::string& filename) -{ - std::vector conv_params; // Return vector - std::ifstream file(filename); // Open CSV file - - if(!file.is_open()) - { - std::cerr << "ERROR: Cannot open CSV file: " << filename << std::endl; - return conv_params; // Return empty vector on error - } - - std::string line; - int line_number = 0; - - // Read file line by line - while(std::getline(file, line)) - { - line_number++; - // Skip comment lines (starting with #) and empty lines - if(line.empty() || line[0] == '#') - { - continue; - } - - // Skip header line (contains column names) - if(line.find("NDim,Groups,BatchSize") != std::string::npos) - { - continue; - } - - // Parse CSV line using stringstream - std::stringstream ss(line); - std::string cell; - std::vector row; - - // Split line by commas - while(std::getline(ss, cell, ',')) - { - row.push_back(cell); - } - - // Validate row has correct number of columns - if(row.size() < 19) - { // Need at least 19 columns for 2D 
(excluding TestName) - std::cerr << "WARNING: Line " << line_number << " has insufficient columns (" - << row.size() << "), skipping" << std::endl; - continue; - } - - try - { - // Parse CSV data into ConvParam structure - // CSV Format: - // NDim,Groups,BatchSize,OutChannels,InChannels,KernelH,KernelW,InputH,InputW,OutputH,OutputW,StrideH,StrideW,DilationH,DilationW,LeftPadH,LeftPadW,RightPadH,RightPadW,TestName - int NDim = std::stoi(row[0]); - int Groups = std::stoi(row[1]); - int BatchSize = std::stoi(row[2]); - int OutChannels = std::stoi(row[3]); - int InChannels = std::stoi(row[4]); - - if(NDim == 2) - { - // 2D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels, - // {KernelH,KernelW}, {InputH,InputW}, {StrideH,StrideW}, {DilationH,DilationW}, - // {LeftPadH,LeftPadW}, {RightPadH,RightPadW}} - ck::utils::conv::ConvParam param = { - NDim, // NDim = 2 - Groups, // Groups - BatchSize, // Batch size - OutChannels, // Output channels - InChannels, // Input channels - {std::stoi(row[5]), std::stoi(row[6])}, // Kernel: {H, W} - {std::stoi(row[7]), std::stoi(row[8])}, // Input: {H, W} - {std::stoi(row[11]), std::stoi(row[12])}, // Stride: {H, W} - {std::stoi(row[13]), std::stoi(row[14])}, // Dilation: {H, W} - {std::stoi(row[15]), std::stoi(row[16])}, // Left pad: {H, W} - {std::stoi(row[17]), std::stoi(row[18])} // Right pad: {H, W} - }; - conv_params.push_back(param); - } - else if(NDim == 3) - { - // 3D Convolution: Need more columns for 3D parameters - if(row.size() < 26) - { - std::cerr << "WARNING: 3D convolution on line " << line_number - << " needs 26+ columns, has " << row.size() << ", skipping" - << std::endl; - continue; - } - // 3D Convolution: {NDim, Groups, BatchSize, OutChannels, InChannels, - // {KernelD,KernelH,KernelW}, {InputD,InputH,InputW}, {OutputD,OutputH,OutputW}, - // {StrideD,StrideH,StrideW}, {DilationD,DilationH,DilationW}, - // {LeftPadD,LeftPadH,LeftPadW}, {RightPadD,RightPadH,RightPadW}} - ck::utils::conv::ConvParam param 
= { - NDim, // NDim = 3 - Groups, // Groups - BatchSize, // Batch size - OutChannels, // Output channels - InChannels, // Input channels - {std::stoi(row[5]), std::stoi(row[6]), std::stoi(row[7])}, // Kernel: {D, H, W} - {std::stoi(row[8]), std::stoi(row[9]), std::stoi(row[10])}, // Input: {D, H, W} - {std::stoi(row[14]), - std::stoi(row[15]), - std::stoi(row[16])}, // Stride: {D, H, W} - {std::stoi(row[17]), - std::stoi(row[18]), - std::stoi(row[19])}, // Dilation: {D, H, W} - {std::stoi(row[20]), - std::stoi(row[21]), - std::stoi(row[22])}, // Left pad: {D, H, W} - {std::stoi(row[23]), - std::stoi(row[24]), - std::stoi(row[25])} // Right pad: {D, H, W} - }; - conv_params.push_back(param); - } - else - { - std::cerr << "WARNING: Unsupported NDim=" << NDim << " on line " << line_number - << ", skipping" << std::endl; - } - } - catch(const std::exception& e) - { - std::cerr << "ERROR: Failed to parse line " << line_number << ": " << e.what() - << std::endl; - continue; - } - } - - file.close(); - std::cout << "Loaded " << conv_params.size() << " test cases from " << filename << std::endl; - return conv_params; -} - -// Template class that works with different data types and tensor layouts -template -class TestGroupedConvndFwd : public ::testing::Test // Inherit from Google Test base class -{ - protected: - using DataType = - std::tuple_element_t<0, Tuple>; // Extract data type from tuple (fp32, fp16, bf16, int8) - using InLayout = - std::tuple_element_t<1, Tuple>; // Extract input tensor layout (NHWGC, NDHWGC, etc.) - using WeiLayout = - std::tuple_element_t<2, Tuple>; // Extract weight tensor layout (GKYXC, GKZYXC, etc.) - using OutLayout = - std::tuple_element_t<3, Tuple>; // Extract output tensor layout (NHWGK, NDHWGK, etc.) 
- using IndexType = ck::long_index_t; // 64-bit integer type for tensor dimensions - - // THE KEY CONTAINER: This stores all test case parameters - // Each test will push_back() ConvParam structures here - std::vector conv_params; - - // Template function to run tests for N-dimensional spatial convolution (2D or 3D) - template - void Run() - { - EXPECT_FALSE(conv_params.empty()); // Google Test assertion: ensure we have test cases - bool pass = true; // Track overall pass/fail across all test cases - - // MAIN LOOP: Execute every test case that was added to conv_params - for(auto& param : conv_params) - { - // CALL THE ACTUAL GPU PROFILER - This is where convolution happens! - pass = pass && - ck::profiler::profile_grouped_conv_fwd_impl( // Index type (int64) - true, // do_verification: Compare GPU result with CPU reference - 1, // init_method: How to initialize random test data (1 = uniform -5 to 5) - false, // do_log: Don't print detailed tensor values - false, // time_kernel: Don't do performance timing (just correctness) - param); // ConvParam: {NDim, Groups, Batch, OutChannels, InChannels, - // KernelSize, InputSize, ...} - } - EXPECT_TRUE(pass); // Google Test assertion: ALL test cases must pass - } -}; +#include "../common/csv_test_loader.hpp" // Shared CSV test case loader using namespace ck::tensor_layout::convolution; // Import tensor layout names (NHWGC, GKYXC, etc.) 
-// GOOGLE TEST TYPE COMBINATIONS: Define what data types and layouts to test -// This creates 4 separate test instances for 2D convolution: -using KernelTypes2d = - ::testing::Types, // fp32 test - std::tuple, // fp16 test - std::tuple, // bfloat16 test - std::tuple>; // int8 test - -// This creates 3 separate test instances for 3D convolution (no int8 support for 3D): -using KernelTypes3d = - ::testing::Types, // fp32 3D test - std::tuple, // fp16 3D test - std::tuple>; // bfloat16 3D test - -// Create specialized test classes that inherit from the base template class -template -class TestGroupedConvndFwd2d : public TestGroupedConvndFwd // 2D convolution test class +// Load CSV data for 2D tests +static std::vector Get2DTestCases() { -}; - -template -class TestGroupedConvndFwd3d : public TestGroupedConvndFwd // 3D convolution test class -{ -}; - -// GOOGLE TEST MAGIC: Create test suites -// This tells Google Test to create 4 test instances for 2D (fp32, fp16, bf16, int8) -TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); -// This tells Google Test to create 3 test instances for 3D (fp32, fp16, bf16) -TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); - -// THE ACTUAL 2D TEST - This runs 4 times (once for each data type: fp32, fp16, bf16, int8) -TYPED_TEST(TestGroupedConvndFwd2d, Test2D) -{ - // LOAD TEST CASES FROM CSV FILE instead of hardcoded cases - // Try different locations for the CSV file (build directory vs source directory) - std::vector csv_paths = { - "../test_data/conv_test_set_2d_dataset.csv", // From build directory to source - }; - - bool loaded = false; - for(const auto& csv_path : csv_paths) + static std::vector test_cases; + if(test_cases.empty()) { - auto csv_cases = load_csv_test_cases(csv_path); - if(!csv_cases.empty()) + std::string test_data_dir = ck::test::GetTestDataPath(); + if(test_data_dir.empty()) { - // Successfully loaded CSV data - add all test cases to conv_params - for(const auto& test_case : csv_cases) - { - 
this->conv_params.push_back(test_case); - } - std::cout << "Loaded " << csv_cases.size() << " 2D test cases from " << csv_path - << std::endl; - loaded = true; - break; + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_2d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "2D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 2D test cases from " << csv_paths[0] << std::endl; } } - - // FAIL if CSV loading fails - no fallback! - if(!loaded) - { - std::cerr << "ERROR: Failed to load CSV test data from any of these locations:" - << std::endl; - for(const auto& path : csv_paths) - { - std::cerr << " - " << path << std::endl; - } - std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl; - std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets." - << std::endl; - - // Force test failure - no test cases means test should fail - EXPECT_TRUE(loaded) << "CSV test data loading failed"; - } - - // Execute all test cases with 2D convolution - // This calls Run<2>() which loops through conv_params and calls GPU profiler for each - this->template Run<2>(); + return test_cases; } -// THE ACTUAL 3D TEST - This runs 3 times (once for each data type: fp32, fp16, bf16) -TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +// Load CSV data for 3D tests +static std::vector Get3DTestCases() { - // LOAD TEST CASES FROM CSV FILE instead of hardcoded cases - // Try different locations for the CSV file (build directory vs source directory) - std::vector csv_paths = { - "../test_data/conv_test_set_3d_dataset.csv", // From build directory to source - }; - - bool loaded = false; - for(const auto& csv_path : csv_paths) + static std::vector test_cases; + if(test_cases.empty()) { - auto csv_cases = load_csv_test_cases(csv_path); - if(!csv_cases.empty()) + std::string test_data_dir = 
ck::test::GetTestDataPath(); + if(test_data_dir.empty()) { - // Successfully loaded CSV data - add all test cases to conv_params - for(const auto& test_case : csv_cases) - { - this->conv_params.push_back(test_case); - } - std::cout << "Loaded " << csv_cases.size() << " 3D test cases from " << csv_path - << std::endl; - loaded = true; - break; + std::cerr << "FATAL: test_data directory not found" << std::endl; + return test_cases; + } + + std::vector csv_paths = {test_data_dir + "/conv_test_set_3d_dataset.csv"}; + bool loaded = ck::test::load_and_populate_test_cases(csv_paths, test_cases, "3D"); + if(!loaded) + { + std::cerr << "FATAL: Failed to load 3D test cases from " << csv_paths[0] << std::endl; } } - - // FAIL if CSV loading fails - no fallback! - if(!loaded) - { - std::cerr << "ERROR: Failed to load CSV test data from any of these locations:" - << std::endl; - for(const auto& path : csv_paths) - { - std::cerr << " - " << path << std::endl; - } - std::cerr << "\nPlease ensure CSV test data exists in one of these locations." << std::endl; - std::cerr << "Run generate_test_dataset.sh in test_data/ to create test datasets." 
- << std::endl; - - // Force test failure - no test cases means test should fail - EXPECT_TRUE(loaded) << "CSV test data loading failed"; - } - - // Execute all test cases with 3D convolution - // This calls Run<3>() which loops through conv_params and calls GPU profiler for each - this->template Run<3>(); + return test_cases; } + +// Helper template to run a single convolution test +template +bool RunConvTest(const ck::utils::conv::ConvParam& param) +{ + using IndexType = ck::long_index_t; + return ck::profiler::profile_grouped_conv_fwd_impl(true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param); +} + +// 2D Tests - Float +class TestGroupedConvndFwd2dFloat : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd2dFloat, ConvTest) +{ + EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, float>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd2dFloat, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - Half +class TestGroupedConvndFwd2dHalf : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd2dHalf, ConvTest) +{ + EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, ck::half_t>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd2dHalf, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - BFloat16 +class TestGroupedConvndFwd2dBFloat16 : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd2dBFloat16, ConvTest) +{ + EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, ck::bhalf_t>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd2dBFloat16, + ::testing::ValuesIn(Get2DTestCases())); + +// 2D Tests - Int8 +class TestGroupedConvndFwd2dInt8 : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd2dInt8, ConvTest) +{ + EXPECT_TRUE((RunConvTest<2, NHWGC, GKYXC, NHWGK, int8_t>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd2dInt8, + ::testing::ValuesIn(Get2DTestCases())); + 
+// 3D Tests - Float +class TestGroupedConvndFwd3dFloat : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd3dFloat, ConvTest) +{ + EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, float>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd3dFloat, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - Half +class TestGroupedConvndFwd3dHalf : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd3dHalf, ConvTest) +{ + EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, ck::half_t>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd3dHalf, + ::testing::ValuesIn(Get3DTestCases())); + +// 3D Tests - BFloat16 +class TestGroupedConvndFwd3dBFloat16 : public ::testing::TestWithParam +{ +}; +TEST_P(TestGroupedConvndFwd3dBFloat16, ConvTest) +{ + EXPECT_TRUE((RunConvTest<3, NDHWGC, GKZYXC, NDHWGK, ck::bhalf_t>(GetParam()))); +} +INSTANTIATE_TEST_SUITE_P(Dataset, + TestGroupedConvndFwd3dBFloat16, + ::testing::ValuesIn(Get3DTestCases())); diff --git a/test_data/generate_test_dataset.sh b/test_data/generate_test_dataset.sh index e9c4937445..27f45a3bc7 100755 --- a/test_data/generate_test_dataset.sh +++ b/test_data/generate_test_dataset.sh @@ -8,6 +8,20 @@ set -e # Exit on error set +x # Disable command echo (even if called with bash -x) +# Trap to kill all background jobs on script exit/interruption +cleanup() { + echo "" + echo "Cleaning up background processes..." + # Kill all jobs in the current process group + jobs -p | xargs -r kill 2>/dev/null || true + wait 2>/dev/null || true + echo "Cleanup complete." + exit 1 +} + +# Set up trap for common termination signals +trap cleanup SIGINT SIGTERM EXIT + echo "==========================================" echo "CK Convolution Test Dataset Generator" echo "==========================================" @@ -18,7 +32,7 @@ if ! python3 -c "import torch" 2>/dev/null; then echo "PyTorch not found. Creating virtual environment..." 
# Create a virtual environment in the current directory - VENV_DIR="./pytorch_venv" + VENV_DIR="./.venv" if [ ! -d "$VENV_DIR" ]; then python3 -m venv $VENV_DIR || { echo "ERROR: Failed to create virtual environment." @@ -66,11 +80,71 @@ if ! $PYTHON_CMD -c "import torch; import sys; sys.exit(0 if torch.cuda.is_avail echo "Continuing anyway to generate placeholder data..." fi +# Parse command line arguments +CONFIG_MODE="full" # Default configuration mode: 'small', 'half' or 'full' +MAX_PARALLEL_JOBS=1 # Default number of parallel jobs +NUM_GPUS=1 # Number of GPUs to use (0 means no GPU assignment) + +# Process arguments +while [[ $# -gt 0 ]]; do + case $1 in + -j) + MAX_PARALLEL_JOBS="$2" + shift 2 + ;; + -j*) + MAX_PARALLEL_JOBS="${1#-j}" + shift + ;; + --gpus) + NUM_GPUS="$2" + shift 2 + ;; + small|half|full) + CONFIG_MODE="$1" + shift + ;; + *) + echo "Usage: $0 [small|half|full] [-j ] [--gpus ]" + echo " Configuration modes: small, half, full (default: full)" + echo " -j : Number of parallel jobs (default: 1)" + echo " --gpus : Number of GPUs to use (e.g., 8 for GPUs 0-7)" + exit 1 + ;; + esac +done + +# Setup GPU array if GPUs are requested +if [ $NUM_GPUS -gt 0 ]; then + # Auto-detect available GPUs + AVAILABLE_GPUS_COUNT=$(rocm-smi --showid 2>/dev/null | grep -oP 'GPU\[\K[0-9]+' | wc -l) + if [ "$AVAILABLE_GPUS_COUNT" -gt 0 ]; then + MAX_AVAILABLE=$AVAILABLE_GPUS_COUNT + else + MAX_AVAILABLE=0 + fi + + # Validate requested GPU count + if [ $NUM_GPUS -gt $MAX_AVAILABLE ]; then + echo "WARNING: Requested $NUM_GPUS GPUs but only $MAX_AVAILABLE available. Using $MAX_AVAILABLE GPUs." 
+ NUM_GPUS=$MAX_AVAILABLE + fi + + # Build GPU array (0 to NUM_GPUS-1) + GPU_ARRAY=() + for ((i=0; i /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_2d.txt || true - + # Run in background + ( + # Set HIP_VISIBLE_DEVICES if GPU was assigned + if [ -n "$GPU_ID" ]; then + export HIP_VISIBLE_DEVICES=$GPU_ID + fi + + MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \ + --model $model --batch-size $batch_size --channels $channels --height $height --width $width --precision $precision \ + > /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_2d.txt || true + echo -e "${GREEN}[DONE]${NC} ${CYAN}2D${NC} ${YELLOW}$CONFIG_NAME${NC}" + ) & + + job_pids+=($!) + + # Limit number of parallel jobs + if [ ${#job_pids[@]} -ge $MAX_PARALLEL_JOBS ]; then + # Wait for any job to complete + wait -n + # Remove completed jobs from array + for i in "${!job_pids[@]}"; do + if ! kill -0 "${job_pids[$i]}" 2>/dev/null; then + unset 'job_pids[$i]' + fi + done + job_pids=("${job_pids[@]}") # Re-index array + fi done < $OUTPUT_DIR/model_configs_2d.csv +# Wait for all remaining 2D jobs to complete +echo "Waiting for remaining 2D jobs to complete..." +wait + +echo "All 2D models processed!" +echo "" + # Process 3D models from CSV configuration file echo "Processing 3D models from $OUTPUT_DIR/model_configs_3d.csv..." 
@@ -175,6 +291,10 @@ CURRENT_3D_CONFIG=0 echo "Total 3D configurations to process: $TOTAL_3D_CONFIGS" echo "" +# Reset job tracking array +declare -a job_pids=() +# GPU counter continues from 2D models for round-robin assignment + # Read 3D configurations from CSV (skip comments and header) while IFS=',' read -r config_name model batch_size channels temporal_size height width precision; do # Skip comments and empty lines @@ -185,21 +305,59 @@ while IFS=',' read -r config_name model batch_size channels temporal_size height # Increment counter CURRENT_3D_CONFIG=$((CURRENT_3D_CONFIG + 1)) - # Build configuration command for 3D models CONFIG="--model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision" CONFIG_NAME="$config_name" - echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC}" + # Assign GPU in round-robin fashion if GPUs are specified + if [ $NUM_GPUS -gt 0 ]; then + GPU_ID=${GPU_ARRAY[$((GPU_COUNTER % NUM_GPUS))]} + GPU_COUNTER=$((GPU_COUNTER + 1)) + echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC} ${PURPLE}[GPU ${GPU_ID}]${NC} - Starting in background" + else + GPU_ID="" + echo -e "${GREEN}[${CURRENT_3D_CONFIG}/${TOTAL_3D_CONFIGS}]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC} - Starting in background" + fi + # Run in background + ( + # Set HIP_VISIBLE_DEVICES if GPU was assigned + if [ -n "$GPU_ID" ]; then + export HIP_VISIBLE_DEVICES=$GPU_ID + fi + + MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \ + --model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision \ + > /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_3d.txt || true + echo -e "${GREEN}[DONE]${NC} ${CYAN}3D${NC} ${YELLOW}$CONFIG_NAME${NC}" + ) & - # Actual run with logging (suppress stdout, only 
capture stderr with MIOpen commands) - MIOPEN_ENABLE_LOGGING_CMD=1 $PYTHON_CMD run_model_with_miopen.py \ - --model $model --batch-size $batch_size --channels $channels --temporal-size $temporal_size --height $height --width $width --precision $precision \ - > /dev/null 2>> $OUTPUT_DIR/${model}_miopen_log_3d.txt || true + job_pids+=($!) + + # Limit number of parallel jobs + if [ ${#job_pids[@]} -ge $MAX_PARALLEL_JOBS ]; then + # Wait for any job to complete + wait -n + # Remove completed jobs from array + for i in "${!job_pids[@]}"; do + if ! kill -0 "${job_pids[$i]}" 2>/dev/null; then + unset 'job_pids[$i]' + fi + done + job_pids=("${job_pids[@]}") # Re-index array + fi done < $OUTPUT_DIR/model_configs_3d.csv +# Wait for all remaining 3D jobs to complete +echo "Waiting for remaining 3D jobs to complete..." +wait + +echo "All 3D models processed!" +echo "" + +# Disable trap on successful completion +trap - SIGINT SIGTERM EXIT echo "" echo "Step 3: Converting MIOpen commands to CSV test cases" @@ -311,7 +469,7 @@ if [ $COUNT_3D -gt 0 ]; then fi echo " - Intermediate files in: $OUTPUT_DIR/" echo "" -echo "To use these datasets:" -echo " 1. Build the test: cd ../script && make -j64 test_grouped_convnd_fwd_dataset_xdl" -echo " 2. Run the test: ./bin/test_grouped_convnd_fwd_dataset_xdl" +echo "To use these datasets for direction (bwd_data, bwd_weight, or fwd):" +echo " 1. Build the test: cd ../script && make -j64 test_grouped_convnd__dataset_xdl" +echo " 2. Run the test: ./bin/test_grouped_convnd__dataset_xdl" echo "" diff --git a/test_data/gtest_parallel.py b/test_data/gtest_parallel.py new file mode 100644 index 0000000000..9ea9ee79b0 --- /dev/null +++ b/test_data/gtest_parallel.py @@ -0,0 +1,1187 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# This file has been modified to allow round-robin GPU scheduling. 
+# Original file can be found at +# https://github.com/google/gtest-parallel/blob/cd488bdedc1d2cffb98201a17afc1b298b0b90f1/gtest_parallel.py +# Changes from the original file are subject to the following license: +# SPDX-License-Identifier: MIT +# +# Copyright 2013 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +from functools import total_ordering +import gzip +import io +import json +import multiprocessing +import optparse +import os +import re +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time + +if sys.version_info.major >= 3: + long = int + import _pickle as cPickle + import _thread as thread +else: + import cPickle + import thread + +from pickle import HIGHEST_PROTOCOL as PICKLE_HIGHEST_PROTOCOL + +if sys.platform == "win32": + import msvcrt +else: + import fcntl + + +# An object that catches SIGINT sent to the Python process and notices +# if processes passed to wait() die by SIGINT (we need to look for +# both of those cases, because pressing Ctrl+C can result in either +# the main process or one of the subprocesses getting the signal). +# +# Before a SIGINT is seen, wait(p) will simply call p.wait() and +# return the result. Once a SIGINT has been seen (in the main process +# or a subprocess, including the one the current call is waiting for), +# wait(p) will call p.terminate() and raise ProcessWasInterrupted. 
+class SigintHandler(object): + class ProcessWasInterrupted(Exception): + pass + + sigint_returncodes = { + -signal.SIGINT, # Unix + -1073741510, # Windows + } + + def __init__(self): + self.__lock = threading.Lock() + self.__processes = set() + self.__got_sigint = False + signal.signal(signal.SIGINT, lambda signal_num, frame: self.interrupt()) + + def __on_sigint(self): + self.__got_sigint = True + while self.__processes: + try: + self.__processes.pop().terminate() + except OSError: + pass + + def interrupt(self): + with self.__lock: + self.__on_sigint() + + def got_sigint(self): + with self.__lock: + return self.__got_sigint + + def wait(self, p, timeout_per_test): + with self.__lock: + if self.__got_sigint: + p.terminate() + self.__processes.add(p) + try: + code = p.wait(timeout_per_test) + except subprocess.TimeoutExpired: + p.terminate() + self.__processes.remove(p) + code = -errno.ETIME + with self.__lock: + self.__processes.discard(p) + if code in self.sigint_returncodes: + self.__on_sigint() + if self.__got_sigint: + raise self.ProcessWasInterrupted + return code + + +sigint_handler = SigintHandler() + + +# Return the width of the terminal, or None if it couldn't be +# determined (e.g. because we're not being run interactively). +def term_width(out): + if not out.isatty(): + return None + try: + p = subprocess.Popen( + ["stty", "size"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + (out, err) = p.communicate() + if p.returncode != 0 or err: + return None + return int(out.split()[1]) + except (IndexError, OSError, ValueError): + return None + + +# Output transient and permanent lines of text. If several transient +# lines are written in sequence, the new will overwrite the old. We +# use this to ensure that lots of unimportant info (tests passing) +# won't drown out important info (tests failing). 
+class Outputter(object): + def __init__(self, out_file): + self.__out_file = out_file + self.__previous_line_was_transient = False + self.__width = term_width(out_file) # Line width, or None if not a tty. + + def transient_line(self, msg): + if self.__width is None: + self.__out_file.write(msg + "\n") + self.__out_file.flush() + else: + self.__out_file.write("\r" + msg[: self.__width].ljust(self.__width)) + self.__previous_line_was_transient = True + + def flush_transient_output(self): + if self.__previous_line_was_transient: + self.__out_file.write("\n") + self.__previous_line_was_transient = False + + def permanent_line(self, msg): + self.flush_transient_output() + self.__out_file.write(msg + "\n") + if self.__width is None: + self.__out_file.flush() + + +def get_available_gpus(num_gpus): + """Get list of available GPU IDs based on HIP_VISIBLE_DEVICES and num_gpus. + + Returns a list of GPU IDs to use. + If HIP_VISIBLE_DEVICES is set, we return the first min(num_gpus, len(HIP_VISIBLE_DEVICES)) GPU IDs from it. + If not set, we return GPU IDs 0 to num_gpus-1. 
+ """ + hip_visible = os.environ.get("HIP_VISIBLE_DEVICES", None) + + # Treat empty string as not set + if hip_visible is not None and hip_visible.strip(): + # Parse HIP_VISIBLE_DEVICES to get the list of available GPU IDs + try: + available_gpu_ids = [ + gpu_id.strip() for gpu_id in hip_visible.split(",") if gpu_id.strip() + ] + except ValueError: + sys.stderr.write( + "Warning: Invalid HIP_VISIBLE_DEVICES format, using GPU 0\n" + ) + return ["0"] + + # If parsing resulted in empty list, treat as not set + if not available_gpu_ids: + return [str(i) for i in range(num_gpus)] + + # Use the first min(num_gpus, len(available_gpu_ids)) GPUs from the list + num_to_use = min(num_gpus, len(available_gpu_ids)) + return available_gpu_ids[:num_to_use] + else: + # If HIP_VISIBLE_DEVICES is not set or empty, use GPU IDs 0 to num_gpus-1 + return [str(i) for i in range(num_gpus)] + + +def get_save_file_path(): + """Return path to file for saving transient data.""" + if sys.platform == "win32": + default_cache_path = os.path.join(os.path.expanduser("~"), "AppData", "Local") + cache_path = os.environ.get("LOCALAPPDATA", default_cache_path) + else: + # We don't use xdg module since it's not a standard. + default_cache_path = os.path.join(os.path.expanduser("~"), ".cache") + cache_path = os.environ.get("XDG_CACHE_HOME", default_cache_path) + + if os.path.isdir(cache_path): + return os.path.join(cache_path, "gtest-parallel") + else: + sys.stderr.write("Directory {} does not exist".format(cache_path)) + return os.path.join(os.path.expanduser("~"), ".gtest-parallel-times") + + +@total_ordering +class Task(object): + """Stores information about a task (single execution of a test). + + This class stores information about the test to be executed (gtest binary and + test name), and its result (log file, exit code and runtime). + Each task is uniquely identified by the gtest binary, the test name and an + execution number that increases each time the test is executed. 
+ Additionally we store the last execution time, so that next time the test is + executed, the slowest tests are run first. + """ + + def __init__( + self, + test_binary, + test_name, + test_command, + execution_number, + last_execution_time, + output_dir, + ): + self.test_name = test_name + self.output_dir = output_dir + self.test_binary = test_binary + self.test_command = test_command + self.execution_number = execution_number + self.last_execution_time = last_execution_time + + self.exit_code = None + self.runtime_ms = None + + self.test_id = (test_binary, test_name) + self.task_id = (test_binary, test_name, self.execution_number) + + self.log_file = Task._logname( + self.output_dir, self.test_binary, test_name, self.execution_number + ) + + def __sorting_key(self): + # Unseen or failing tests (both missing execution time) take precedence over + # execution time. Tests are greater (seen as slower) when missing times so + # that they are executed first. + return (1 if self.last_execution_time is None else 0, self.last_execution_time) + + def __eq__(self, other): + return self.__sorting_key() == other.__sorting_key() + + def __ne__(self, other): + return not (self == other) + + def __lt__(self, other): + return self.__sorting_key() < other.__sorting_key() + + @staticmethod + def _normalize(string): + return re.sub("[^A-Za-z0-9]", "_", string) + + @staticmethod + def _logname(output_dir, test_binary, test_name, execution_number): + # Store logs to temporary files if there is no output_dir. 
+ if output_dir is None: + (log_handle, log_name) = tempfile.mkstemp( + prefix="gtest_parallel_", suffix=".log" + ) + os.close(log_handle) + return log_name + + log_name = "%s-%s-%d.log" % ( + Task._normalize(os.path.basename(test_binary)), + Task._normalize(test_name), + execution_number, + ) + + return os.path.join(output_dir, log_name) + + def run(self, timeout_per_test, gpu_id=None): + begin = time.time() + with open(self.log_file, "w") as log: + # Set up environment with GPU assignment if specified + env = os.environ.copy() + if gpu_id is not None: + env["HIP_VISIBLE_DEVICES"] = str(gpu_id) + + # Get the absolute path to the test binary and its directory + # This handles both relative and absolute paths correctly + abs_test_binary = os.path.abspath(self.test_binary) + test_binary_dir = os.path.dirname(abs_test_binary) + + # Update the test command to use the absolute path + abs_test_command = [abs_test_binary] + self.test_command[1:] + + task = subprocess.Popen( + abs_test_command, stdout=log, stderr=log, env=env, cwd=test_binary_dir + ) + try: + self.exit_code = sigint_handler.wait(task, timeout_per_test) + except sigint_handler.ProcessWasInterrupted: + thread.exit() + self.runtime_ms = int(1000 * (time.time() - begin)) + self.last_execution_time = None if self.exit_code else self.runtime_ms + + +class TaskManager(object): + """Executes the tasks and stores the passed, failed and interrupted tasks. + + When a task is run, this class keeps track if it passed, failed or was + interrupted. After a task finishes it calls the relevant functions of the + Logger, TestResults and TestTimes classes, and in case of failure, retries the + test as specified by the --retry_failed flag. 
+ """ + + def __init__( + self, + times, + logger, + test_results, + task_factory, + times_to_retry, + initial_execution_number, + ): + self.times = times + self.logger = logger + self.test_results = test_results + self.task_factory = task_factory + self.times_to_retry = times_to_retry + self.initial_execution_number = initial_execution_number + + self.global_exit_code = 0 + + self.passed = [] + self.failed = [] + self.started = {} + self.timed_out = [] + self.execution_number = {} + + self.lock = threading.Lock() + + def __get_next_execution_number(self, test_id): + with self.lock: + next_execution_number = self.execution_number.setdefault( + test_id, self.initial_execution_number + ) + self.execution_number[test_id] += 1 + return next_execution_number + + def __register_start(self, task): + with self.lock: + self.started[task.task_id] = task + + def register_exit(self, task): + self.logger.log_exit(task) + self.times.record_test_time( + task.test_binary, task.test_name, task.last_execution_time + ) + if self.test_results: + self.test_results.log( + task.test_name, task.runtime_ms / 1000.0, task.exit_code + ) + + with self.lock: + self.started.pop(task.task_id) + if task.exit_code == 0: + self.passed.append(task) + elif task.exit_code == -errno.ETIME: + self.timed_out.append(task) + else: + self.failed.append(task) + + def run_task(self, task, timeout_per_test, gpu_id=None): + for try_number in range(self.times_to_retry + 1): + self.__register_start(task) + task.run(timeout_per_test, gpu_id) + self.register_exit(task) + + if task.exit_code == 0: + break + + if try_number < self.times_to_retry: + execution_number = self.__get_next_execution_number(task.test_id) + # We need create a new Task instance. Each task represents a single test + # execution, with its own runtime, exit code and log file. 
+ task = self.task_factory( + task.test_binary, + task.test_name, + task.test_command, + execution_number, + task.last_execution_time, + task.output_dir, + ) + + with self.lock: + if task.exit_code != 0: + self.global_exit_code = task.exit_code + + +class FilterFormat(object): + def __init__(self, output_dir): + if sys.stdout.isatty(): + # stdout needs to be unbuffered since the output is interactive. + if isinstance(sys.stdout, io.TextIOWrapper): + # workaround for https://bugs.python.org/issue17404 + sys.stdout = io.TextIOWrapper( + sys.stdout.detach(), + line_buffering=True, + write_through=True, + newline="\n", + ) + else: + sys.stdout = os.fdopen(sys.stdout.fileno(), "w", 0) + + self.output_dir = output_dir + + self.total_tasks = 0 + self.finished_tasks = 0 + self.out = Outputter(sys.stdout) + self.stdout_lock = threading.Lock() + + def move_to(self, destination_dir, tasks): + if self.output_dir is None: + return + + destination_dir = os.path.join(self.output_dir, destination_dir) + os.makedirs(destination_dir) + for task in tasks: + shutil.move(task.log_file, destination_dir) + + def print_tests(self, message, tasks, print_try_number, print_test_command): + self.out.permanent_line("%s (%s/%s):" % (message, len(tasks), self.total_tasks)) + for task in sorted(tasks): + runtime_ms = "Interrupted" + if task.runtime_ms is not None: + runtime_ms = "%d ms" % task.runtime_ms + if print_test_command: + try: + cmd_str = " ".join(task.test_command) + except TypeError: + cmd_str = task.test_command + self.out.permanent_line( + "%11s: %s%s" + % ( + runtime_ms, + cmd_str, + (" (try #%d)" % task.execution_number) + if print_try_number + else "", + ) + ) + else: + self.out.permanent_line( + "%11s: %s %s%s" + % ( + runtime_ms, + task.test_binary, + task.test_name, + (" (try #%d)" % task.execution_number) + if print_try_number + else "", + ) + ) + + def log_exit(self, task): + with self.stdout_lock: + self.finished_tasks += 1 + self.out.transient_line( + "[%d/%d] %s (%d ms)" + 
% ( + self.finished_tasks, + self.total_tasks, + task.test_name, + task.runtime_ms, + ) + ) + if task.exit_code != 0: + signal_name = None + if task.exit_code < 0: + try: + signal_name = signal.Signals(-task.exit_code).name + except ValueError: + pass + + with open(task.log_file) as f: + for line in f.readlines(): + self.out.permanent_line(line.rstrip()) + if task.exit_code is None: + self.out.permanent_line( + "[%d/%d] %s aborted after %d ms" + % ( + self.finished_tasks, + self.total_tasks, + task.test_name, + task.runtime_ms, + ) + ) + elif task.exit_code == -errno.ETIME: + self.out.permanent_line( + "\033[31m[ TIMEOUT ]\033[0m %s timed out after %d s" + % (task.test_name, task.runtime_ms / 1000) + ) + elif signal_name is not None: + self.out.permanent_line( + "[%d/%d] %s killed by signal %s (%d ms)" + % ( + self.finished_tasks, + self.total_tasks, + task.test_name, + signal_name, + task.runtime_ms, + ) + ) + else: + self.out.permanent_line( + "[%d/%d] %s returned with exit code %d (%d ms)" + % ( + self.finished_tasks, + self.total_tasks, + task.test_name, + task.exit_code, + task.runtime_ms, + ) + ) + + if self.output_dir is None: + # Try to remove the file 100 times (sleeping for 0.1 second in between). + # This is a workaround for a process handle seemingly holding on to the + # file for too long inside os.subprocess. This workaround is in place + # until we figure out a minimal repro to report upstream (or a better + # suspect) to prevent os.remove exceptions. + num_tries = 100 + for i in range(num_tries): + try: + os.remove(task.log_file) + except OSError as e: + if e.errno is not errno.ENOENT: + if i is num_tries - 1: + self.out.permanent_line( + "Could not remove temporary log file: " + str(e) + ) + else: + time.sleep(0.1) + continue + break + + def log_tasks(self, total_tasks): + self.total_tasks += total_tasks + self.out.transient_line("[0/%d] Running tests..." 
% self.total_tasks) + + def summarize(self, passed_tasks, failed_tasks, interrupted_tasks): + stats = {} + + def add_stats(stats, task, idx): + task_key = (task.test_binary, task.test_name) + if task_key not in stats: + # (passed, failed, interrupted) task_key is added as tie breaker to get + # alphabetic sorting on equally-stable tests + stats[task_key] = [0, 0, 0, task_key] + stats[task_key][idx] += 1 + + for task in passed_tasks: + add_stats(stats, task, 0) + for task in failed_tasks: + add_stats(stats, task, 1) + for task in interrupted_tasks: + add_stats(stats, task, 2) + + self.out.permanent_line("SUMMARY:") + for task_key in sorted(stats, key=stats.__getitem__): + (num_passed, num_failed, num_interrupted, _) = stats[task_key] + (test_binary, task_name) = task_key + total_runs = num_passed + num_failed + num_interrupted + if num_passed == total_runs: + continue + self.out.permanent_line( + " %s %s passed %d / %d times%s." + % ( + test_binary, + task_name, + num_passed, + total_runs, + "" + if num_interrupted == 0 + else (" (%d interrupted)" % num_interrupted), + ) + ) + + def flush(self): + self.out.flush_transient_output() + + +class CollectTestResults(object): + def __init__(self, json_dump_filepath): + self.test_results_lock = threading.Lock() + self.json_dump_file = open(json_dump_filepath, "w") + self.test_results = { + "interrupted": False, + "path_delimiter": ".", + # Third version of the file format. See the link in the flag description + # for details. 
+ "version": 3, + "seconds_since_epoch": int(time.time()), + "num_failures_by_type": { + "PASS": 0, + "FAIL": 0, + "TIMEOUT": 0, + }, + "tests": {}, + } + + def log(self, test, runtime_seconds, exit_code): + if exit_code is None: + actual_result = "TIMEOUT" + elif exit_code == 0: + actual_result = "PASS" + else: + actual_result = "FAIL" + with self.test_results_lock: + self.test_results["num_failures_by_type"][actual_result] += 1 + results = self.test_results["tests"] + for name in test.split("."): + results = results.setdefault(name, {}) + + if results: + results["actual"] += " " + actual_result + results["times"].append(runtime_seconds) + else: # This is the first invocation of the test + results["actual"] = actual_result + results["times"] = [runtime_seconds] + results["time"] = runtime_seconds + results["expected"] = "PASS" + + def dump_to_file_and_close(self): + json.dump(self.test_results, self.json_dump_file) + self.json_dump_file.close() + + +# Record of test runtimes. Has built-in locking. +class TestTimes(object): + class LockedFile(object): + def __init__(self, filename, mode): + self._filename = filename + self._mode = mode + self._fo = None + + def __enter__(self): + self._fo = open(self._filename, self._mode) + + # Regardless of opening mode we always seek to the beginning of file. + # This simplifies code working with LockedFile and also ensures that + # we lock (and unlock below) always the same region in file on win32. + self._fo.seek(0) + + try: + if sys.platform == "win32": + # We are locking here fixed location in file to use it as + # an exclusive lock on entire file. + msvcrt.locking(self._fo.fileno(), msvcrt.LK_LOCK, 1) + else: + fcntl.flock(self._fo.fileno(), fcntl.LOCK_EX) + except IOError: + self._fo.close() + raise + + return self._fo + + def __exit__(self, exc_type, exc_value, traceback): + # Flush any buffered data to disk. 
This is needed to prevent race + # condition which happens from the moment of releasing file lock + # till closing the file. + self._fo.flush() + + try: + if sys.platform == "win32": + self._fo.seek(0) + msvcrt.locking(self._fo.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(self._fo.fileno(), fcntl.LOCK_UN) + finally: + self._fo.close() + + return exc_value is None + + def __init__(self, save_file): + "Create new object seeded with saved test times from the given file." + self.__times = {} # (test binary, test name) -> runtime in ms + + # Protects calls to record_test_time(); other calls are not + # expected to be made concurrently. + self.__lock = threading.Lock() + + try: + with TestTimes.LockedFile(save_file, "rb") as fd: + times = TestTimes.__read_test_times_file(fd) + except IOError: + # We couldn't obtain the lock. + return + + # Discard saved times if the format isn't right. + if type(times) is not dict: + return + for (test_binary, test_name), runtime in times.items(): + if ( + type(test_binary) is not str + or type(test_name) is not str + or type(runtime) not in {int, long, type(None)} + ): + return + + self.__times = times + + def get_test_time(self, binary, testname): + """Return the last duration for the given test as an integer number of + milliseconds, or None if the test failed or if there's no record for it.""" + return self.__times.get((binary, testname), None) + + def record_test_time(self, binary, testname, runtime_ms): + """Record that the given test ran in the specified number of + milliseconds. If the test failed, runtime_ms should be None.""" + with self.__lock: + self.__times[(binary, testname)] = runtime_ms + + def write_to_file(self, save_file): + "Write all the times to file." + try: + with TestTimes.LockedFile(save_file, "a+b") as fd: + times = TestTimes.__read_test_times_file(fd) + + if times is None: + times = self.__times + else: + times.update(self.__times) + + # We erase data from file while still holding a lock to it. 
This + # way reading old test times and appending new ones are atomic + # for external viewer. + fd.seek(0) + fd.truncate() + with gzip.GzipFile(fileobj=fd, mode="wb") as gzf: + cPickle.dump(times, gzf, PICKLE_HIGHEST_PROTOCOL) + except IOError: + pass # ignore errors---saving the times isn't that important + + @staticmethod + def __read_test_times_file(fd): + try: + with gzip.GzipFile(fileobj=fd, mode="rb") as gzf: + times = cPickle.load(gzf) + except Exception: + # File doesn't exist, isn't readable, is malformed---whatever. + # Just ignore it. + return None + else: + return times + + +def find_tests(binaries, additional_args, options, times): + test_count = 0 + tasks = [] + for test_binary in binaries: + command = [test_binary] + additional_args + if options.gtest_also_run_disabled_tests: + command += ["--gtest_also_run_disabled_tests"] + + list_command = command + ["--gtest_list_tests"] + if options.gtest_filter != "": + list_command += ["--gtest_filter=" + options.gtest_filter] + + # Get absolute path and directory for the test binary + abs_test_binary = os.path.abspath(test_binary) + test_binary_dir = os.path.dirname(abs_test_binary) + + # Create list command with absolute path + abs_list_command = [abs_test_binary] + additional_args + ["--gtest_list_tests"] + if options.gtest_also_run_disabled_tests: + abs_list_command += ["--gtest_also_run_disabled_tests"] + if options.gtest_filter != "": + abs_list_command += ["--gtest_filter=" + options.gtest_filter] + + try: + # Run the list command from the binary's directory so relative paths work + test_list = subprocess.check_output( + abs_list_command, stderr=subprocess.STDOUT, cwd=test_binary_dir + ) + except subprocess.CalledProcessError as e: + sys.exit("%s: %s\n%s" % (test_binary, str(e), e.output)) + + try: + test_list = test_list.split("\n") + except TypeError: + # subprocess.check_output() returns bytes in python3 + test_list = test_list.decode(sys.stdout.encoding).split("\n") + + command += ["--gtest_color=" 
+ options.gtest_color] + + test_group = "" + for line in test_list: + if not line.strip(): + continue + if line[0] != " ": + # Remove comments for typed tests and strip whitespace. + test_group = line.split("#")[0].strip() + continue + # Remove comments for parameterized tests and strip whitespace. + line = line.split("#")[0].strip() + if not line: + continue + + test_name = test_group + line + if not options.gtest_also_run_disabled_tests and "DISABLED_" in test_name: + continue + + # Skip PRE_ tests which are used by Chromium. + if ".PRE_" in test_name: + continue + + last_execution_time = times.get_test_time(test_binary, test_name) + if options.failed and last_execution_time is not None: + continue + + test_command = command + ["--gtest_filter=" + test_name] + if (test_count - options.shard_index) % options.shard_count == 0: + for execution_number in range(options.repeat): + tasks.append( + Task( + test_binary, + test_name, + test_command, + execution_number + 1, + last_execution_time, + options.output_dir, + ) + ) + + test_count += 1 + + # Sort the tasks to run the slowest tests first, so that faster ones can be + # finished in parallel. + return sorted(tasks, reverse=True) + + +def execute_tasks( + tasks, + pool_size, + task_manager, + timeout_seconds, + timeout_per_test, + serialize_test_cases, + available_gpus=None, +): + class WorkerFn(object): + def __init__(self, tasks, running_groups, timeout_per_test, available_gpus): + self.tasks = tasks + self.running_groups = running_groups + self.timeout_per_test = timeout_per_test + self.available_gpus = available_gpus + self.task_lock = threading.Lock() + self.task_counter = 0 + + def __call__(self): + while True: + gpu_id = None + with self.task_lock: + for task_id in range(len(self.tasks)): + task = self.tasks[task_id] + + if self.running_groups is not None: + test_group = task.test_name.split(".")[0] + if test_group in self.running_groups: + # Try to find other non-running test group. 
+ continue + else: + self.running_groups.add(test_group) + + # Assign GPU in round-robin fashion if GPUs are available + if self.available_gpus: + gpu_id = self.available_gpus[ + self.task_counter % len(self.available_gpus) + ] + self.task_counter += 1 + + del self.tasks[task_id] + break + else: + # Either there is no tasks left or number or remaining test + # cases (groups) is less than number or running threads. + return + + task_manager.run_task(task, self.timeout_per_test, gpu_id) + + if self.running_groups is not None: + with self.task_lock: + self.running_groups.remove(test_group) + + def start_daemon(func): + t = threading.Thread(target=func) + t.daemon = True + t.start() + return t + + timeout = None + try: + if timeout_seconds: + timeout = threading.Timer(timeout_seconds, sigint_handler.interrupt) + timeout.start() + running_groups = set() if serialize_test_cases else None + worker_fn = WorkerFn(tasks, running_groups, timeout_per_test, available_gpus) + workers = [start_daemon(worker_fn) for _ in range(pool_size)] + for worker in workers: + worker.join() + finally: + if timeout: + timeout.cancel() + for task in list(task_manager.started.values()): + task.runtime_ms = timeout_seconds * 1000 + task_manager.register_exit(task) + + +def default_options_parser(): + parser = optparse.OptionParser( + usage="usage: %prog [options] binary [binary ...] -- [additional args]" + ) + + parser.add_option( + "-d", + "--output_dir", + type="string", + default=None, + help="Output directory for test logs. 
Logs will be " + "available under gtest-parallel-logs/, so " + "--output_dir=/tmp will results in all logs being " + "available under /tmp/gtest-parallel-logs/.", + ) + parser.add_option( + "-r", + "--repeat", + type="int", + default=1, + help="Number of times to execute all the tests.", + ) + parser.add_option( + "--retry_failed", + type="int", + default=0, + help="Number of times to repeat failed tests.", + ) + parser.add_option( + "--failed", + action="store_true", + default=False, + help="run only failed and new tests", + ) + parser.add_option( + "-w", + "--workers", + type="int", + default=multiprocessing.cpu_count(), + help="number of workers to spawn", + ) + parser.add_option( + "--gpus", + type="int", + default=1, + help="number of GPUs to use for parallel execution (default: 1)", + ) + parser.add_option( + "--gtest_color", type="string", default="yes", help="color output" + ) + parser.add_option("--gtest_filter", type="string", default="", help="test filter") + parser.add_option( + "--gtest_also_run_disabled_tests", + action="store_true", + default=False, + help="run disabled tests too", + ) + parser.add_option( + "--print_test_times", + action="store_true", + default=False, + help="list the run time of each test at the end of execution", + ) + parser.add_option( + "--print_test_command", + action="store_true", + default=False, + help="Print full test command instead of name", + ) + parser.add_option( + "--shard_count", + type="int", + default=1, + help="total number of shards (for sharding test execution " + "between multiple machines)", + ) + parser.add_option( + "--shard_index", + type="int", + default=0, + help="zero-indexed number identifying this shard (for " + "sharding test execution between multiple machines)", + ) + parser.add_option( + "--dump_json_test_results", + type="string", + default=None, + help="Saves the results of the tests as a JSON machine-" + "readable file. 
The format of the file is specified at " + "https://www.chromium.org/developers/the-json-test-results-format", + ) + parser.add_option( + "--timeout", + type="int", + default=None, + help="Interrupt all remaining processes after the given time (in seconds).", + ) + parser.add_option( + "--timeout_per_test", + type="int", + default=None, + help="Interrupt single processes after the given time (in seconds).", + ) + parser.add_option( + "--serialize_test_cases", + action="store_true", + default=False, + help="Do not run tests from the same test case in parallel.", + ) + return parser + + +def main(): + # Remove additional arguments (anything after --). + additional_args = [] + + for i in range(len(sys.argv)): + if sys.argv[i] == "--": + additional_args = sys.argv[i + 1 :] + sys.argv = sys.argv[:i] + break + + parser = default_options_parser() + (options, binaries) = parser.parse_args() + + if options.output_dir is not None and not os.path.isdir(options.output_dir): + parser.error( + "--output_dir value must be an existing directory, " + 'current value is "%s"' % options.output_dir + ) + + # Append gtest-parallel-logs to log output, this is to avoid deleting user + # data if an user passes a directory where files are already present. If a + # user specifies --output_dir=Docs/, we'll create Docs/gtest-parallel-logs + # and clean that directory out on startup, instead of nuking Docs/. + if options.output_dir: + options.output_dir = os.path.join(options.output_dir, "gtest-parallel-logs") + + if binaries == []: + parser.print_usage() + sys.exit(1) + + if options.shard_count < 1: + parser.error( + "Invalid number of shards: %d. Must be at least 1." % options.shard_count + ) + if not (0 <= options.shard_index < options.shard_count): + parser.error( + "Invalid shard index: %d. Must be between 0 and %d " + "(less than the number of shards)." + % (options.shard_index, options.shard_count - 1) + ) + + # Check that all test binaries have an unique basename. 
That way we can ensure + # the logs are saved to unique files even when two different binaries have + # common tests. + unique_binaries = set(os.path.basename(binary) for binary in binaries) + assert len(unique_binaries) == len(binaries), ( + "All test binaries must have an unique basename." + ) + + if options.output_dir: + # Remove files from old test runs. + if os.path.isdir(options.output_dir): + shutil.rmtree(options.output_dir) + # Create directory for test log output. + try: + os.makedirs(options.output_dir) + except OSError as e: + # Ignore errors if this directory already exists. + if e.errno != errno.EEXIST or not os.path.isdir(options.output_dir): + raise e + + test_results = None + if options.dump_json_test_results is not None: + test_results = CollectTestResults(options.dump_json_test_results) + + save_file = get_save_file_path() + + times = TestTimes(save_file) + logger = FilterFormat(options.output_dir) + + task_manager = TaskManager( + times, logger, test_results, Task, options.retry_failed, options.repeat + 1 + ) + + # Get available GPUs based on HIP_VISIBLE_DEVICES and --gpus option + available_gpus = get_available_gpus(options.gpus) if options.gpus > 0 else None + + tasks = find_tests(binaries, additional_args, options, times) + logger.log_tasks(len(tasks)) + execute_tasks( + tasks, + options.workers, + task_manager, + options.timeout, + options.timeout_per_test, + options.serialize_test_cases, + available_gpus, + ) + + print_try_number = options.retry_failed > 0 or options.repeat > 1 + if task_manager.passed: + logger.move_to("passed", task_manager.passed) + if options.print_test_times: + logger.print_tests( + "PASSED TESTS", + task_manager.passed, + print_try_number, + options.print_test_command, + ) + + if task_manager.failed: + logger.print_tests( + "FAILED TESTS", + task_manager.failed, + print_try_number, + options.print_test_command, + ) + logger.move_to("failed", task_manager.failed) + + if task_manager.timed_out: + logger.print_tests( + 
"TIMED OUT TESTS", + task_manager.timed_out, + print_try_number, + options.print_test_command, + ) + logger.move_to("timed_out", task_manager.timed_out) + + if task_manager.started: + logger.print_tests( + "INTERRUPTED TESTS", + task_manager.started.values(), + print_try_number, + options.print_test_command, + ) + logger.move_to("interrupted", task_manager.started.values()) + + if options.repeat > 1 and (task_manager.failed or task_manager.started): + logger.summarize( + task_manager.passed, task_manager.failed, task_manager.started.values() + ) + + logger.flush() + times.write_to_file(save_file) + if test_results: + test_results.dump_to_file_and_close() + + if sigint_handler.got_sigint(): + return -signal.SIGINT + + return task_manager.global_exit_code + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index bb0b8090fa..8c0c5f78d4 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name) return traits; } - -template -auto shuffle_b(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 
2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); -} - -template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile, - ck_tile::index_t N_Tile, - ck_tile::index_t N_Warp) -{ - assert(t.get_lengths().size() == 2); - - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 2 : 4; - int NRepeat = N_Tile / N_Warp_Tile / N_Warp; - ck_tile::HostTensor t_view({n_ / N_Tile, - N_Warp, - N_Warp_Tile, - NRepeat, - k_ / K_Warp_Tile, - divisor, - K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); -} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 739bd7e677..cad53b472f 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -111,21 +111,30 @@ class GemmProfiler c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); + struct GemmConfig + { + ck_tile::index_t N_Warp_Tile; + ck_tile::index_t K_Warp_Tile; + ck_tile::index_t N_Tile; + ck_tile::index_t N_Warp; + }; + for(const auto& callable : callables) { - ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); - ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims); - ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); - ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); + GemmConfig gemmConfig = {}; + gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims); + gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims); + gemmConfig.N_Tile = std::get<1>(config.tile_dims); + gemmConfig.N_Warp = std::get<1>(config.warp_dims); 
ck_tile::HostTensor b_shuffle_host = [&]() { if(config.permuteN) { - return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); + return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig); } else { - return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + return ck_tile::shuffle_b(b_k_n, gemmConfig); } }();