mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
refactor ck-tile kernel launch (#1925)
This commit is contained in:
@@ -38,9 +38,20 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
|
||||
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
return hipPeekAtLastError() == hipSuccess;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
|
||||
{
|
||||
// abort the sequence in case of intermediate error
|
||||
if(!(callables(sc) && ...))
|
||||
{
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
/*
|
||||
* launch_kernel()
|
||||
@@ -69,38 +80,39 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
|
||||
**/
|
||||
// clang-format on
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
// clang-format off
|
||||
if(!s.time_kernel_) {
|
||||
(callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
if(s.is_gpu_timer_) {
|
||||
gpu_timer timer {};
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
};
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
}
|
||||
else {
|
||||
cpu_timer timer {};
|
||||
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user