[CK_Tile] Updating gpu timer when doing flush cache (#2593)

* Missed updating function names in example

* updating timer

* code cleanup

* addressing review comments

* updating tile_engine code

* addressing review comments
This commit is contained in:
Khushbu Agarwal
2025-07-31 16:43:33 -07:00
committed by GitHub
parent 546ef78d1d
commit 88d72178d6
11 changed files with 54 additions and 139 deletions

View File

@@ -458,7 +458,8 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("bench_time_ms", "0", "benchmark time in ms, defaults to 0 ms");
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1", "rotating count, defaults to 1");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -184,7 +184,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat,
bool persistent,
int bench_time_ms)
bool flush_cache,
int rotating_count)
{
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
@@ -214,7 +215,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
}
else
{
@@ -232,7 +233,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
}
std::size_t flop = std::size_t(2) * M * N * K;
@@ -303,7 +304,8 @@ int run_gemm_example_with_layouts(int argc,
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
int bench_time_ms = arg_parser.get_int("bench_time_ms");
bool flush_cache = arg_parser.get_bool("flush_cache");
int rotating_count = arg_parser.get_int("rotating_count");
const bool preshuffle = GemmConfig::Preshuffle;
@@ -422,7 +424,8 @@ int run_gemm_example_with_layouts(int argc,
n_warmup,
n_repeat,
persistent,
bench_time_ms);
flush_cache,
rotating_count);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;