[ck_tile] enable test grouped_gemm_quant and gemm_streamk on gfx12 (#3196)

1. Enable grouped_gemm_quant and gemm_streamk on gfx12
- test_ck_tile_streamk_smoke is kept on gfx9, since it looks someone is still working on it.
2. Update warp tile size in grouped_gemm_quant and gemm_streamk unit test
3. Reduce gemm tile size to pass the build on gfx12 in test_gemm_streamk_reboot_types.hpp
This commit is contained in:
linqunAMD
2025-11-20 08:40:27 +08:00
committed by GitHub
parent cd8af997e6
commit d2e32b4305
6 changed files with 55 additions and 24 deletions

View File

@@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
)
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
# Currently test_ck_tile_streamk is only built on gfx9
# Currently test_ck_tile_streamk_smoke is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
@@ -140,6 +140,13 @@ if(GPU_TARGETS MATCHES "gfx9")
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
# )
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping test_ck_tile_streamk_smoke for current target")
endif()
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
@@ -153,7 +160,6 @@ if(GPU_TARGETS MATCHES "gfx9")
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target")
endif()

View File

@@ -19,38 +19,38 @@ using Persistent = std::true_type;
using NonPersistent = std::false_type;
using I32 = ck_tile::number<32>;
using I256 = ck_tile::number<256>;
using I128 = ck_tile::number<128>;
// clang-format off
using KernelTypesStreamKFp16Persistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>
>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>
>;
using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
std::tuple< Row, Row, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>
>;
using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>
>;
// clang-format on

View File

@@ -69,11 +69,15 @@ class TestCkTileStreamKReboot : public ::testing::Test
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
#if CK_TILE_USE_WMMA
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;