mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
[ck_tile] enable test grouped_gemm_quant and gemm_streamk on gfx12 (#3196)
1. Enable grouped_gemm_quant and gemm_streamk on gfx12. test_ck_tile_streamk_smoke is kept on gfx9 only, since it looks like someone is still working on it.
2. Update the warp tile size in the grouped_gemm_quant and gemm_streamk unit tests.
3. Reduce the gemm tile size in test_gemm_streamk_reboot_types.hpp so that the build passes on gfx12.
This commit is contained in:
@@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
)
|
||||
set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
# Currently test_ck_tile_streamk is only built on gfx9
|
||||
# Currently test_ck_tile_streamk_smoke is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
@@ -140,6 +140,13 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk_smoke for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
@@ -153,7 +160,6 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -19,38 +19,38 @@ using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesStreamKFp16Persistent = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16Persistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I128, I128, I32, NonPersistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I128, I128, I32, NonPersistent>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -69,11 +69,15 @@ class TestCkTileStreamKReboot : public ::testing::Test
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
#endif
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
|
||||
@@ -3,7 +3,7 @@ if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# Split into three separate test executables for faster parallel compilation
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -74,6 +74,17 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
M_Warp_Tile>();
|
||||
};
|
||||
|
||||
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
|
||||
{
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
@@ -373,8 +384,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
|
||||
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
auto b_shuffle_host =
|
||||
ck_tile::shuffle_b<GroupedGemKernelParam_Wmma>(b_k_n_tensors[i]);
|
||||
#else
|
||||
auto b_shuffle_host =
|
||||
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
|
||||
#endif
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
@@ -446,8 +462,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
#if CK_TILE_USE_WMMA
|
||||
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr);
|
||||
#else
|
||||
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user