From 930f95d4a634240bc3e22b51eaf027d68cc20c88 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Sat, 13 Sep 2025 03:45:14 +0800 Subject: [PATCH] [CK_TILE] Enable ck_tile tests on gfx11 and gfx12 (#2821) * [CK_TILE] Enable ck_tile test on gfx11 & gfx12 * revert an unnecessary change * enable pk_int4 on gfx11 & gfx12 * revert .pre-commit-config.yaml [ROCm/composable_kernel commit: b0ee317d83b77741022997265d4125697e7f7804] --- example/ck_tile/03_gemm/run_gemm_example.inc | 8 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 8 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 8 +- ...p_gemm_attribute_wmma_impl_base_traits.hpp | 4 +- .../add_rmsnorm2d_rdquant/CMakeLists.txt | 2 +- test/ck_tile/batched_gemm/CMakeLists.txt | 3 +- .../batched_gemm/test_batched_gemm_util.hpp | 44 +++++++--- test/ck_tile/batched_transpose/CMakeLists.txt | 3 +- test/ck_tile/container/CMakeLists.txt | 2 +- test/ck_tile/data_type/CMakeLists.txt | 2 +- test/ck_tile/elementwise/CMakeLists.txt | 2 +- .../elementwise/test_elementwise_1d.cpp | 2 +- test/ck_tile/gemm/CMakeLists.txt | 70 ++++++++------- .../test_gemm_pipeline_basic_run_test.inc | 65 ++++++++++---- .../gemm/test_gemm_pipeline_smoke_util.hpp | 21 +++++ .../test_gemm_pipeline_universal_run_test.inc | 8 ++ test/ck_tile/gemm_multi_d/CMakeLists.txt | 3 +- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 53 ++++++++++-- .../gemm_weight_preshuffle/CMakeLists.txt | 2 +- .../test_gemm_pipeline_kernel_types.hpp | 6 +- .../test_gemm_pipeline_util.hpp | 86 +++++++++++++++---- test/ck_tile/grouped_gemm/CMakeLists.txt | 2 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 33 +++++-- test/ck_tile/image_to_column/CMakeLists.txt | 3 +- test/ck_tile/layernorm2d/CMakeLists.txt | 2 +- test/ck_tile/moe_smoothquant/CMakeLists.txt | 3 +- test/ck_tile/moe_sorting/CMakeLists.txt | 4 +- test/ck_tile/permute/CMakeLists.txt | 3 +- test/ck_tile/permute/test_permute_util.hpp | 4 + test/ck_tile/reduce/CMakeLists.txt | 2 +- test/ck_tile/reduce/test_reduce2d.cpp | 2 +- test/ck_tile/rmsnorm2d/CMakeLists.txt | 2 +- test/ck_tile/smoothquant/CMakeLists.txt | 3 +- test/ck_tile/topk_softmax/CMakeLists.txt | 3 +- 34 files changed, 338 insertions(+), 130 deletions(-) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index cc980a75f7..e6875f97d5 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -181,15 +181,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, GemmConfig::N_Warp_Tile, k_ / GemmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index a8abcee41e..1ae0844032 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -314,15 +314,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, GemmConfig::N_Warp_Tile, k_ / GemmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 63d0a80555..c187f72594 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -45,15 +45,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / FlatmmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index 7a3190e6f4..86bae7655b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -70,9 +70,9 @@ struct WmmaTraitsBase static constexpr index_t kRepeat = 1; static constexpr index_t kAMLane = 16; static constexpr index_t kBNLane = 16; - static constexpr index_t kABK0PerLane = 2; + static constexpr index_t kABK0PerLane = 1; static constexpr index_t kABKLane = 2; - static constexpr index_t kABK1PerLane = 4; + static constexpr index_t kABK1PerLane = 8; static constexpr index_t kCMLane = 2; static constexpr index_t kCNLane = 16; diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt index 37774f7643..64672e200b 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt @@ -18,7 +18,7 @@ function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX) set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") create_tile_add_rmsnorm2d_rdquant_fwd("fp16") create_tile_add_rmsnorm2d_rdquant_fwd("bf16") else() diff --git a/test/ck_tile/batched_gemm/CMakeLists.txt b/test/ck_tile/batched_gemm/CMakeLists.txt index 532ead1124..9bcbc7352e 100644 --- a/test/ck_tile/batched_gemm/CMakeLists.txt +++ b/test/ck_tile/batched_gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp) endif() diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index f634e508e3..1e2ea45b9e 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -27,21 +27,41 @@ class TestCkTileBatchedGemm : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - template + struct GemmWarpConfig_Mfma + { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + }; + + struct GemmWarpConfig_Wmma + { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + }; + + template void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; constexpr bool DoubleSmemBuffer = false; @@ -255,9 +275,13 @@ class TestCkTileBatchedGemm : public ::testing::Test BatchStrideB, BatchStrideC, BatchCount}; - - invoke_batched_gemm(args, - ck_tile::stream_config{nullptr, false}); +#if CK_TILE_USE_WMMA + invoke_batched_gemm( + args, ck_tile::stream_config{nullptr, false}); +#else + invoke_batched_gemm( + args, ck_tile::stream_config{nullptr, false}); +#endif std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC diff --git a/test/ck_tile/batched_transpose/CMakeLists.txt b/test/ck_tile/batched_transpose/CMakeLists.txt index 111b7c2bed..fb45caf044 100644 --- a/test/ck_tile/batched_transpose/CMakeLists.txt +++ b/test/ck_tile/batched_transpose/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx950") add_gtest_executable(test_ck_tile_batched_transpose test_batched_transpose.cpp) set_property(TARGET test_ck_tile_batched_transpose PROPERTY CXX_STANDARD 20) else() diff --git a/test/ck_tile/container/CMakeLists.txt b/test/ck_tile/container/CMakeLists.txt index 50670c83e4..f13f0dbedf 100644 --- a/test/ck_tile/container/CMakeLists.txt +++ b/test/ck_tile/container/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility) diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index 384fd3c1c4..a5713ac55c 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) endif() if(GPU_TARGETS MATCHES "gfx95") diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt index d22a30ff56..5fca0eb801 100644 --- a/test/ck_tile/elementwise/CMakeLists.txt +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 3ce6e78d1d..2eb2b506e8 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -106,7 +106,7 @@ class TestCkTileElementwise : public ::testing::Test ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM; dim3 grid(grid_size, 1, 1); - dim3 block(TestElementWiseShape::kBlockSize, 1, 1); + dim3 block = dim3(ew_kernel.BlockSize()); constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 5d34943e0d..44e2433060 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -12,16 +12,16 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -enable-noalias-to-md-conversion=0 ) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) - - target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) - +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") + add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target") +endif() +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) @@ -30,37 +30,47 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - -elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - # On Radeon devices, build the WMMA version instead - add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a") - add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target ") +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") + if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + endif() + + if(GPU_TARGETS MATCHES "gfx11|gfx12") + # On Radeon devices, build the WMMA version instead + add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + endif() +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline") endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 1fdf26f01c..706035cabc 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -13,6 +13,28 @@ #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" +struct GemmConfig_Mfma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + +struct GemmConfig_Wmma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template , @@ -130,7 +152,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } } -template +template bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, ck_tile::ArgParser& arg_parser) @@ -142,12 +167,12 @@ bool run_gemm_test_prec_type(std::string a_layout, { if(a_layout == "R" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Col{}, Row{}); } else @@ -160,22 +185,22 @@ bool run_gemm_test_prec_type(std::string a_layout, { if(a_layout == "R" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Col{}, Row{}); } else @@ -185,7 +210,7 @@ bool run_gemm_test_prec_type(std::string a_layout, } } -template +template bool run_gemm_test(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -195,7 +220,8 @@ bool run_gemm_test(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - return run_gemm_test_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_test_prec_type( + a_layout, b_layout, arg_parser); } template @@ -255,8 +281,15 @@ int run_gemm_combinations() // Call the function with the current configuration try { - is_success = run_gemm_test(ARG_COUNT, argv) && +#if CK_TILE_USE_WMMA + is_success = run_gemm_test( + ARG_COUNT, argv) && is_success; +#else + is_success = run_gemm_test( + ARG_COUNT, argv) && + is_success; +#endif } catch(const ArgumentsNotSupportedException& e) { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index f64d3e092b..52f6ea7026 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -220,6 +220,27 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct GemmConfigComputeV3_WMMA : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmTypeConfig; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index fd50596f2f..dfee45cdfd 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -325,6 +325,13 @@ int run_gemm_combinations() // Call the function with the current configuration try { +#if CK_TILE_USE_WMMA + is_success = run_gemm_test, + APrecType, + BPrecType, + CPrecType>(ARG_COUNT, argv) && + is_success; +#else is_success = run_gemm_test, APrecType, BPrecType, @@ -335,6 +342,7 @@ int run_gemm_combinations() BPrecType, CPrecType>(ARG_COUNT, argv) && is_success; +#endif } catch(const ArgumentsNotSupportedException& e) { diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index c9d53e53e2..143fb9dc40 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -1,10 +1,9 @@ -# Currently ck_tile is only built on gfx9 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp) add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp) target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8399bc7ee3..f0050c15d5 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -86,7 +86,28 @@ class TestCkTileGemmMultiD : public ::testing::Test using DsLayout = ck_tile::tuple; using DsDataType = ck_tile::tuple; - template & args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; constexpr bool DoubleSmemBuffer = false; @@ -359,8 +380,9 @@ class TestCkTileGemmMultiD : public ::testing::Test StrideB, stridesDs, StrideE}); - - invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); +#else + invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); +#endif std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE diff --git a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt index 4b9e6049e3..90803bd9d5 100644 --- a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt +++ b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt @@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -enable-noalias-to-md-conversion=0 ) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_wp test_gemm_pipeline_wp.cpp) target_compile_options(test_ck_tile_gemm_pipeline_wp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp index f66f3cb0aa..b1521fc35a 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp @@ -31,8 +31,10 @@ using F8Types = std::tuple, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>, - F8Types + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle> +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + , F8Types +#endif >; // clang-format on diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 5d52f15696..42d0149498 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -63,6 +63,23 @@ struct config static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; }; + +template +struct config_wmma +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(Datatype); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -79,13 +96,12 @@ class TestCkTileGemmPipeline : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - using GemmConfig = config; static constexpr bool Persistent = ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? - template + template void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests @@ -253,6 +269,48 @@ class TestCkTileGemmPipeline : public ::testing::Test k_batches_ = {1}; } + template + auto shuffle_b(const ck_tile::HostTensor& t) + { + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + } template void Run(const int M, const int N, @@ -263,11 +321,17 @@ class TestCkTileGemmPipeline : public ::testing::Test { for(auto kb : k_batches_) { - RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); +#if CK_TILE_USE_WMMA + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); +#else + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); +#endif } } - template + template void RunSingle(const int M, const int N, const int K, @@ -327,16 +391,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({N / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - K / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(b_k_n.begin(), b_k_n.end(), t_view.begin()); - ck_tile::HostTensor b_shuffle_host = - ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); @@ -354,7 +409,8 @@ class TestCkTileGemmPipeline : public ::testing::Test stride_B, stride_C}; - invoke_gemm(args, ck_tile::stream_config{nullptr, false}); + invoke_gemm( + args, ck_tile::stream_config{nullptr, false}); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/test/ck_tile/grouped_gemm/CMakeLists.txt b/test/ck_tile/grouped_gemm/CMakeLists.txt index f4845847f1..4fd5c82ae9 100644 --- a/test/ck_tile/grouped_gemm/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ # Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp) endif() diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 5aca02a433..6893318ea2 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -31,7 +31,7 @@ class TestCkTileGroupedGemm : public ::testing::Test using PersistentType = std::tuple_element_t<7, Tuple>; static constexpr bool Persistent = PersistentType::value; - struct GroupedGemKernelParam + struct GroupedGemKernelParam_Mfma { static const bool kPadM = false; static const bool kPadN = false; @@ -51,13 +51,24 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; + struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma + { + static const ck_tile::index_t M_Tile = 128; + static const ck_tile::index_t N_Tile = 128; + static const ck_tile::index_t K_Tile = 64; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 16; + }; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } - template + template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -200,7 +211,7 @@ class TestCkTileGroupedGemm : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } - template + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, @@ -460,15 +471,27 @@ class TestCkTileGroupedGemm : public ::testing::Test kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent( +#if CK_TILE_USE_WMMA + invoke_grouped_gemm_persistent( stream, group_count, kargs_ptr, splitk); +#else + invoke_grouped_gemm_persistent( + stream, group_count, kargs_ptr, splitk); +#endif } else { - invoke_grouped_gemm( +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false, 1}, gemm_workspace.GetDeviceBuffer()); +#else + invoke_grouped_gemm( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); +#endif } // Copy results back to host for validation diff --git a/test/ck_tile/image_to_column/CMakeLists.txt b/test/ck_tile/image_to_column/CMakeLists.txt index 247358dd4d..8873a846fc 100644 --- a/test/ck_tile/image_to_column/CMakeLists.txt +++ b/test/ck_tile/image_to_column/CMakeLists.txt @@ -1,4 +1,3 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp) endif() diff --git a/test/ck_tile/layernorm2d/CMakeLists.txt b/test/ck_tile/layernorm2d/CMakeLists.txt index c909d6cf40..e924f39e7a 100644 --- a/test/ck_tile/layernorm2d/CMakeLists.txt +++ b/test/ck_tile/layernorm2d/CMakeLists.txt @@ -14,7 +14,7 @@ function(create_tile_layernorm2d_fwd SUFFIX) target_compile_options(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") diff --git a/test/ck_tile/moe_smoothquant/CMakeLists.txt b/test/ck_tile/moe_smoothquant/CMakeLists.txt index b6c8a395b6..019e87323f 100644 --- a/test/ck_tile/moe_smoothquant/CMakeLists.txt +++ b/test/ck_tile/moe_smoothquant/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") add_gtest_executable(${TARGET_NAME} ${MAIN_SRC}) diff --git a/test/ck_tile/moe_sorting/CMakeLists.txt b/test/ck_tile/moe_sorting/CMakeLists.txt index 5abc7df5a9..48d8e1392f 100644 --- a/test/ck_tile/moe_sorting/CMakeLists.txt +++ b/test/ck_tile/moe_sorting/CMakeLists.txt @@ -1,5 +1,5 @@ -# Currently ck_tile is only built on gfx90a, gfx942 and gfx950 -if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a") +# Currently ck_tile is only built on gfx90a, gfx942, gfx950, gfx11 and gfx12 +if(GPU_TARGETS MATCHES "gfx942|gfx950|gfx90a|gfx11|gfx12") function(add_moe_sorting_test EXECUTABLE USE_2D_BUF) add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp) diff --git a/test/ck_tile/permute/CMakeLists.txt b/test/ck_tile/permute/CMakeLists.txt index 4256ad8de1..8574813be3 100644 --- a/test/ck_tile/permute/CMakeLists.txt +++ b/test/ck_tile/permute/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function(add_permute_test TARGET_NAME MAIN_SRC) add_gtest_executable(${TARGET_NAME} ${MAIN_SRC}) diff --git a/test/ck_tile/permute/test_permute_util.hpp b/test/ck_tile/permute/test_permute_util.hpp index 5494749541..2028f56bb8 100644 --- a/test/ck_tile/permute/test_permute_util.hpp +++ b/test/ck_tile/permute/test_permute_util.hpp @@ -17,9 +17,11 @@ #include #include +#if !CK_TILE_USE_WMMA #ifdef PERMUTE_USE_ALTERNATIVE_IMPL #include "alternative_impl/matrix_core_swizzle.hpp" #endif +#endif namespace detail { template @@ -193,6 +195,7 @@ class TestCkTilePermute : public ::testing::Test return permute(a, stream_config); }; +#if !CK_TILE_USE_WMMA #ifdef PERMUTE_USE_ALTERNATIVE_IMPL // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 if((perm == std::string("0,1,4,2,5,3,6") || perm == std::string("0,1,2,4,5,3,6") || @@ -278,6 +281,7 @@ class TestCkTilePermute : public ::testing::Test } } else +#endif #endif { run_permute(); diff --git a/test/ck_tile/reduce/CMakeLists.txt b/test/ck_tile/reduce/CMakeLists.txt index 052669e20a..0ba5974f6c 100644 --- a/test/ck_tile/reduce/CMakeLists.txt +++ b/test/ck_tile/reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_reduce2d PRIVATE utility) diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index ff807e52c9..ded0406797 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -59,7 +59,7 @@ class TestCkTileReduce : public ::testing::Test using Kernel = ck_tile::Reduce; // Launch configuration - constexpr ck_tile::index_t kBlockSize = 256; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = diff --git a/test/ck_tile/rmsnorm2d/CMakeLists.txt b/test/ck_tile/rmsnorm2d/CMakeLists.txt index 5a73b0914c..c60d73aafd 100644 --- a/test/ck_tile/rmsnorm2d/CMakeLists.txt +++ b/test/ck_tile/rmsnorm2d/CMakeLists.txt @@ -14,7 +14,7 @@ function(create_tile_rmsnorm2d_fwd SUFFIX) target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") diff --git a/test/ck_tile/smoothquant/CMakeLists.txt b/test/ck_tile/smoothquant/CMakeLists.txt index 548fc03a41..381923803f 100644 --- a/test/ck_tile/smoothquant/CMakeLists.txt +++ b/test/ck_tile/smoothquant/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") diff --git a/test/ck_tile/topk_softmax/CMakeLists.txt b/test/ck_tile/topk_softmax/CMakeLists.txt index 046eaf6649..cd524eca01 100644 --- a/test/ck_tile/topk_softmax/CMakeLists.txt +++ b/test/ck_tile/topk_softmax/CMakeLists.txt @@ -10,8 +10,7 @@ function(add_tile_topk_softmax_test SUFFIX) target_compile_options(${TEST_NAME} PRIVATE ${TEST_TOPK_SOFTMAX_COMPILE_OPTIONS}) endfunction() -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_tile_topk_softmax_test(fp16) add_tile_topk_softmax_test(bf16) else()