Merge commit 'b0ee317d83b77741022997265d4125697e7f7804' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-12 20:11:58 +00:00
parent facbc883fa
commit 302aa809ea
65 changed files with 2301 additions and 232 deletions

View File

@@ -1,4 +1,3 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
endif()

View File

@@ -27,21 +27,41 @@ class TestCkTileBatchedGemm : public ::testing::Test
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
template <typename ALayout, typename BLayout, typename CLayout>
// Compile-time tile-size bundle passed as the GemmWarpConfig template argument
// of invoke_batched_gemm, which reads the six constants below. The test selects
// this config when CK_TILE_USE_WMMA evaluates to false (the #else branch).
// NOTE(review): the "Mfma" suffix presumably refers to the MFMA matrix
// instruction path on gfx9-class targets -- confirm against the ck_tile docs.
struct GemmWarpConfig_Mfma
{
// Tile extents along M, N and K; invoke_batched_gemm copies them into its own
// M_Tile / N_Tile / K_Tile constants.
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
// Warp-tile extents along M, N and K; invoke_batched_gemm copies them into its
// own M_Warp_Tile / N_Warp_Tile / K_Warp_Tile constants.
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
// Compile-time tile-size bundle passed as the GemmWarpConfig template argument
// of invoke_batched_gemm, which reads the six constants below. The test selects
// this config when CK_TILE_USE_WMMA evaluates to true (the #if branch).
// NOTE(review): the "Wmma" suffix presumably refers to the WMMA matrix
// instruction path on gfx11/gfx12-class targets -- confirm against the ck_tile docs.
struct GemmWarpConfig_Wmma
{
// Tile extents along M, N and K; M and N are half of the values used by
// GemmWarpConfig_Mfma (128 vs 256), K is the same (64).
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
// Warp-tile extents along M, N and K; M and N are 16 here versus 32 in
// GemmWarpConfig_Mfma, K is the same (16).
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename GemmWarpConfig, typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false;
@@ -255,9 +275,13 @@ class TestCkTileBatchedGemm : public ::testing::Test
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
ck_tile::stream_config{nullptr, false});
#if CK_TILE_USE_WMMA
invoke_batched_gemm<GemmWarpConfig_Wmma, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, false});
#else
invoke_batched_gemm<GemmWarpConfig_Mfma, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, false});
#endif
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC