mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Add an option to change the number of warm-up cycles and iterations. (#1124)
* allow setting the number of warm-up cycles and iterations for the profiler
* fix the gemm_splitk and grouped_gemm examples
[ROCm/composable_kernel commit: 886d9eeb99]
This commit is contained in:
@@ -42,7 +42,9 @@ int profile_gemm_impl(int do_verification,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC)
|
||||
int StrideC,
|
||||
int n_warmup,
|
||||
int n_iter)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -165,8 +167,8 @@ int profile_gemm_impl(int do_verification,
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, 10, 50});
|
||||
float avg_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
|
||||
@@ -42,7 +42,9 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
int KBatch)
|
||||
int KBatch,
|
||||
int n_warmup,
|
||||
int n_iter)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -177,7 +179,8 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -200,8 +203,8 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
|
||||
@@ -42,7 +42,9 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -261,7 +263,8 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -307,8 +310,8 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user