[CK_TILE] Tile loop persistent gemm kernel (#2191)

* Implement tile loop persistent gemm kernel

* Enable timing

* Add tests for persistent gemm

* Fix formatting

* Fix gemm_basic

* Rename True/False to Persistent/NonPersistent

* Use only one set of layouts for persistent tests

* Fix gemm example persistent template parameter

* Fix formatting

[ROCm/composable_kernel commit: ffb52783d0]
This commit is contained in:
Sami Remes
2025-06-04 11:46:28 +03:00
committed by GitHub
parent 2d67076cc0
commit 0385ef2437
10 changed files with 232 additions and 18 deletions

View File

@@ -18,9 +18,12 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;

View File

@@ -213,7 +213,8 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -226,5 +227,6 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent = false>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -162,7 +162,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
int n_repeat,
bool persistent)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
@@ -176,9 +177,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
float ave_time;
if(persistent)
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
true>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
else
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
false>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -193,8 +216,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
@@ -229,6 +252,7 @@ int run_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -316,7 +340,8 @@ int run_gemm_example_with_layouts(int argc,
stride_C,
kbatch,
n_warmup,
n_repeat);
n_repeat,
persistent);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -32,7 +32,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -61,7 +62,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
GemmConfig::UseStructuredSparsity,
Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -111,7 +113,15 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))