[CK_TILE] Tile loop persistent gemm kernel (#2191)

* Implement tile loop persistent gemm kernel

* Enable timing

* Add tests for persistent gemm

* Fix formatting

* Fix gemm_basic

* Rename True/False to Persistent/NonPersistent

* Use only one set of layouts for persistent tests

* Fix gemm example persistent template parameter

* Fix formatting

[ROCm/composable_kernel commit: ffb52783d0]
This commit is contained in:
Sami Remes
2025-06-04 11:46:28 +03:00
committed by GitHub
parent 2d67076cc0
commit 0385ef2437
10 changed files with 232 additions and 18 deletions

View File

@@ -23,3 +23,8 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
else()
message("Skipping ck_tile_gemm tests for current target")
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <type_traits>
#include "gtest/gtest.h"
@@ -21,6 +22,9 @@ using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
// clang-format off
using KernelTypesMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
@@ -59,4 +63,9 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
>;
using KernelTypesPersistent = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3, NonPersistent>
>;
// clang-format on

View File

@@ -0,0 +1,16 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelinePersistent : public TestCkTileGemmPipeline<T>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelinePersistent
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesPersistent);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -89,6 +89,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK>
@@ -130,14 +132,17 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
static constexpr bool StructuredSparsity = false;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC>;
TransposeC,
StructuredSparsity,
Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -190,7 +195,15 @@ class TestCkTileGemmPipeline : public ::testing::Test
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -442,9 +455,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
EXPECT_TRUE(pass);
}
};