addressing review comments

This commit is contained in:
khuagarw
2025-07-31 05:10:15 +00:00
parent 6437788c6c
commit 8931ac18bd
4 changed files with 31 additions and 43 deletions

35
include/ck_tile/host/kernel_launch.hpp Executable file → Normal file
View File

@@ -66,20 +66,21 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
}
// Measure the average time of a single preprocess() call.
//
// preprocess() is invoked s.nrepeat_ times between timer.start() and
// timer.stop() on s.stream_id_, and the elapsed duration is divided by the
// number of invocations.
//
// Template parameters:
//   TimerType      - timer object exposing start(stream), stop(stream) and duration().
//   PreprocessFunc - callable invoked with no arguments, or std::nullptr_t when
//                    there is no preprocess step (the loop is then skipped at
//                    compile time).
//
// Returns timer.duration() / s.nrepeat_, or 0 when s.nrepeat_ is not positive.
template <typename TimerType, typename PreprocessFunc>
CK_TILE_HOST double
preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
{
    // Nothing would be measured and the average below would divide by zero;
    // mirror the "no iterations -> 0" convention used by timing_loop_impl.
    if(s.nrepeat_ <= 0)
    {
        return 0.;
    }

    timer.start(s.stream_id_);
    // The branch depends only on the type, so it is resolved once outside the
    // loop; a std::nullptr_t "callable" is never invoked.
    if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
    {
        for(int i = 0; i < s.nrepeat_; i++)
        {
            preprocess();
        }
    }
    timer.stop(s.stream_id_);

    return timer.duration() / s.nrepeat_;
}
template <typename PreprocessFunc>
@@ -88,27 +89,31 @@ CK_TILE_HOST float preprocess_profiling(const stream_config& s, PreprocessFunc p
return preprocess_profiling_impl(gpu_timer{}, s, preprocess);
}
template <typename TimerType, typename CallablesFunc>
template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
CK_TILE_HOST double timing_loop_impl(TimerType timer,
const stream_config& s,
CallablesFunc&& callables_func,
std::function<void()> preprocess = nullptr)
PreprocessFunc preprocess = nullptr)
{
for(int i = 0; i < s.cold_niters_; i++)
{
callables_func();
}
timer.start(s.stream_id_);
// Only profile preprocess if it's provided
auto preprocess_time = 0.0;
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
}
auto profile_time = preprocess_profiling(s, preprocess);
std::vector<float> times;
int i = 0;
timer.start(s.stream_id_);
while(i < s.nrepeat_)
{
if(preprocess)
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
callables_func();
i++;
@@ -117,7 +122,7 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
if(!i)
return 0.;
return (timer.duration() / s.nrepeat_) - profile_time;
return (timer.duration() / s.nrepeat_) - preprocess_time;
}
// clang-format off
@@ -187,7 +192,7 @@ launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Calla
if(s.is_gpu_timer_)
{
return timing_loop_impl(gpu_timer{}, s, callables_func);
return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
}
else
{