diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6326a97f8e..677ccb5ee3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ba555df88d..75305a1336 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 5c7bf12bfc..376027ec98 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -38,9 +38,20 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt return [=](const stream_config& s) { kernel<<>>(args...); + return hipPeekAtLastError() == hipSuccess; }; } +template +CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables) +{ + // abort the sequence in case of intermediate error + if(!(callables(sc) && ...)) + { + HIP_CHECK_ERROR(hipGetLastError()); + } +} + // clang-format off /* * launch_kernel() @@ -69,38 +80,39 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt **/ // clang-format on template -CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables) { - // clang-format off - if(!s.time_kernel_) { - (callables(s),...); HIP_CHECK_ERROR(hipGetLastError()); + if(!s.time_kernel_) + { + launch_and_check(s, std::forward(callables)...); return 0; } - if(s.is_gpu_timer_) { - gpu_timer timer {}; + auto time_launches = [&](auto timer) { // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); + for(int i = 0; i < s.cold_niters_; i++) + { + launch_and_check(s, std::forward(callables)...); + } timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); + for(int i = 0; i < s.nrepeat_; i++) + { + launch_and_check(s, std::forward(callables)...); + } timer.stop(s.stream_id_); return timer.duration() / s.nrepeat_; + }; + + if(s.is_gpu_timer_) + { + return time_launches(gpu_timer{}); } - else { - cpu_timer timer {}; - - // warmup - for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); - - timer.start(s.stream_id_); - for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError()); - timer.stop(s.stream_id_); - - return timer.duration() / s.nrepeat_; + else + { + return time_launches(cpu_timer{}); } - // clang-format on } } // namespace ck_tile