Update to gpu_timer for rotating_buffer (#2524)

* update gpu_timer for rotating buffer as hipblasLt's implementation

* timing fix

* Updating gpu timer for old ck as well

* Revert "Updating gpu timer for old ck as well"

This reverts commit 958cd1bc99.

* code clean up with runtime argument; function rename

* code cleanup

* general timer fixes

* bug fix

* clang formatted

* addressing reveiew comments

* clang formatted

* Addressing review comments

* CI fix

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-07-29 15:21:05 -07:00
committed by GitHub
parent b80099cc5f
commit 61e21f5567
13 changed files with 182 additions and 78 deletions

View File

@@ -457,7 +457,8 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent");
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("bench_time_ms", "0", "benchmark time in ms, defaults to 0 ms");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -146,18 +146,14 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
@@ -173,7 +169,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(

View File

@@ -183,7 +183,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat,
bool persistent)
bool persistent,
int bench_time_ms)
{
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
@@ -211,7 +212,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CLayout,
true,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
}
else
{
@@ -227,7 +230,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
}
std::size_t flop = std::size_t(2) * M * N * K;
@@ -236,15 +241,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
std::cout << "Run Gemm kernel with \n M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " Persistent=" << (persistent ? "on" : "off") << " : \n"
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
@@ -297,6 +303,7 @@ int run_gemm_example_with_layouts(int argc,
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
int bench_time_ms = arg_parser.get_int("bench_time_ms");
const bool preshuffle = GemmConfig::Preshuffle;
@@ -414,7 +421,8 @@ int run_gemm_example_with_layouts(int argc,
kbatch,
n_warmup,
n_repeat,
persistent);
persistent,
bench_time_ms);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -147,18 +147,14 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
@@ -174,7 +170,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(