[CK_TILE] Tile loop persistent gemm kernel (#2191)

* Implement tile loop persistent gemm kernel

* Enable timing

* Add tests for persistent gemm

* Fix formatting

* Fix gemm_basic

* Rename True/False to Persistent/NonPersistent

* Use only one set of layouts for persistent tests

* Fix gemm example persistent template parameter

* Fix formatting
This commit is contained in:
Sami Remes
2025-06-04 11:46:28 +03:00
committed by GitHub
parent 52b4860a30
commit ffb52783d0
10 changed files with 232 additions and 18 deletions

View File

@@ -32,7 +32,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -61,7 +62,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
GemmConfig::UseStructuredSparsity,
Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -111,7 +113,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))