From 4517b3a8da34fc37c585d9931ebc01f6ce53cd38 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 7 Mar 2025 08:29:40 -0800 Subject: [PATCH] refactor ck-tile kernel launch (#1925) [ROCm/composable_kernel commit: 9e132eb77cceef03cf986c7ff9b140b8815f4e11] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 6 +-- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +- include/ck_tile/host/kernel_launch.hpp | 52 ++++++++++++------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6326a97f8e..677ccb5ee3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ba555df88d..75305a1336 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -253,8 +253,8 
@@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 5c7bf12bfc..376027ec98 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -38,9 +38,20 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt return [=](const stream_config& s) { kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...); + return hipPeekAtLastError() == hipSuccess; }; } +template <typename... Callables> +CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables) +{ + // abort the sequence in case of intermediate error + if(!(callables(sc) && ...)) + { + HIP_CHECK_ERROR(hipGetLastError()); + } +} + // clang-format off /* * launch_kernel() @@ -69,38 +80,39 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt **/ // clang-format on template <typename... Callables> -CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... 
callables) { - // clang-format off - if(!s.time_kernel_) { - (callables(s),...); HIP_CHECK_ERROR(hipGetLastError()); + if(!s.time_kernel_) + { + launch_and_check(s, std::forward<Callables>(callables)...); return 0; } - if(s.is_gpu_timer_) { - gpu_timer timer {}; + auto time_launches = [&](auto timer) { // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); + for(int i = 0; i < s.cold_niters_; i++) + { + launch_and_check(s, std::forward<Callables>(callables)...); + } timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); + for(int i = 0; i < s.nrepeat_; i++) + { + launch_and_check(s, std::forward<Callables>(callables)...); + } timer.stop(s.stream_id_); return timer.duration() / s.nrepeat_; + }; + + if(s.is_gpu_timer_) + { + return time_launches(gpu_timer{}); } - else { - cpu_timer timer {}; - - // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); - - timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); - timer.stop(s.stream_id_); - - return timer.duration() / s.nrepeat_; + else + { + return time_launches(cpu_timer{}); } - // clang-format on } } // namespace ck_tile