From ac0fb4fec58e7c8097b134f48a91d92f5cc1f338 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Thu, 20 Nov 2025 08:40:27 +0800 Subject: [PATCH] [ck_tile] enable test grouped_gemm_quant and gemm_streamk on gfx12 (#3196) 1. Enable grouped_gemm_quant and gemm_streamk on gfx12 - test_ck_tile_streamk_smoke is kept on gfx9, since it looks like someone is still working on it. 2. Update warp tile size in grouped_gemm_quant and gemm_streamk unit test 3. Reduce gemm tile size to pass the build on gfx12 in test_gemm_streamk_reboot_types.hpp [ROCm/composable_kernel commit: d2e32b43052c914186403954580af32e2d2c4dc0] --- .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 2 +- test/ck_tile/gemm_streamk/CMakeLists.txt | 12 +++++-- .../test_gemm_streamk_reboot_types.hpp | 34 +++++++++---------- .../test_gemm_streamk_reboot_util.hpp | 8 +++-- .../ck_tile/grouped_gemm_quant/CMakeLists.txt | 2 +- .../test_grouped_gemm_util_quant.hpp | 21 ++++++++++++ 6 files changed, 55 insertions(+), 24 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index d8850749f1..a32e2faf5d 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -662,7 +662,7 @@ struct StreamKKernel hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); - return occupancy; + return max(occupancy, 1); } }; } // namespace reboot diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 150181d0d7..3e3345dd0e 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS ) set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) -# Currently test_ck_tile_streamk is only built on gfx9 +# Currently test_ck_tile_streamk_smoke is only built on gfx9 
if(GPU_TARGETS MATCHES "gfx9") include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -140,6 +140,13 @@ if(GPU_TARGETS MATCHES "gfx9") # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp # ) + target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping test_ck_tile_streamk_smoke for current target") +endif() + +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}) add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) add_gtest_executable(test_ck_tile_streamk_reboot_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp @@ -153,7 +160,6 @@ if(GPU_TARGETS MATCHES "gfx9") ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp test_gemm_streamk_reboot_util.cpp) - target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() - message(DEBUG "Skipping test_ck_tile_streamk tests for current target") + message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target") endif() diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp index 1db53ddd64..f01f7e142f 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp @@ -19,38 +19,38 @@ using Persistent = std::true_type; using NonPersistent = std::false_type; using I32 = ck_tile::number<32>; -using I256 = ck_tile::number<256>; +using I128 = ck_tile::number<128>; // clang-format off using KernelTypesStreamKFp16Persistent = ::testing::Types< // 
ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> + std::tuple< Row, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent> >; using KernelTypesStreamKBf16Persistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent> + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent> >; using KernelTypesStreamKFp16NonPersistent = ::testing::Types< // ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent - std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent> + std::tuple< Row, Row, Row, 
F16, F16, F32, F16, I128, I128, I32, NonPersistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent> >; using KernelTypesStreamKBf16NonPersistent = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, - std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent> + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent> >; // clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp index 85863989b0..c3605cbcda 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp @@ -69,11 +69,15 @@ class TestCkTileStreamKReboot : public ::testing::Test constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - +#if CK_TILE_USE_WMMA + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#else constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; - +#endif constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; diff --git 
a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 3f32413f59..c9399e54dc 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -3,7 +3,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Split into three separate test executables for faster parallel compilation add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index d82deb7305..4e48907317 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -74,6 +74,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test M_Warp_Tile>(); }; + struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma + { + static const ck_tile::index_t M_Tile = 128; + static const ck_tile::index_t N_Tile = 128; + static const ck_tile::index_t K_Tile = 128; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 16; + }; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { @@ -373,8 +384,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped) { +#if CK_TILE_USE_WMMA + auto b_shuffle_host = + ck_tile::shuffle_b(b_k_n_tensors[i]); +#else auto b_shuffle_host = ck_tile::shuffle_b(b_k_n_tensors[i]); +#endif b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); } 
else @@ -446,8 +462,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); +#if CK_TILE_USE_WMMA + invoke_grouped_gemm_persistent( + stream, group_count, kargs_ptr); +#else invoke_grouped_gemm_persistent( stream, group_count, kargs_ptr); +#endif } else {