Merge commit 'b0ee317d83b77741022997265d4125697e7f7804' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-12 20:11:58 +00:00
parent facbc883fa
commit 302aa809ea
65 changed files with 2301 additions and 232 deletions

View File

@@ -1,4 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
endif()

View File

@@ -31,7 +31,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
using PersistentType = std::tuple_element_t<7, Tuple>;
static constexpr bool Persistent = PersistentType::value;
struct GroupedGemKernelParam
struct GroupedGemKernelParam_Mfma
{
static const bool kPadM = false;
static const bool kPadN = false;
@@ -51,13 +51,24 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 16;
};
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
{
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 64;
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
@@ -200,7 +211,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
@@ -460,15 +471,27 @@ class TestCkTileGroupedGemm : public ::testing::Test
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
#if CK_TILE_USE_WMMA
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
#else
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
#endif
}
else
{
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
#if CK_TILE_USE_WMMA
invoke_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
#else
invoke_grouped_gemm<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
#endif
}
// Copy results back to host for validation