refactor ck-tile kernel launch (#1925)

[ROCm/composable_kernel commit: 9e132eb77c]
This commit is contained in:
Max Podkorytov
2025-03-07 08:29:40 -08:00
committed by GitHub
parent 7771db8ecf
commit 4517b3a8da
3 changed files with 37 additions and 25 deletions

View File

@@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
);
}}

View File

@@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
);
}}

View File

@@ -38,9 +38,20 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
return hipPeekAtLastError() == hipSuccess;
};
}
template <typename... Callables>
CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
{
// abort the sequence in case of intermediate error
if(!(callables(sc) && ...))
{
HIP_CHECK_ERROR(hipGetLastError());
}
}
// clang-format off
/*
* launch_kernel()
@@ -69,38 +80,39 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
**/
// clang-format on
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
{
// clang-format off
if(!s.time_kernel_) {
(callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
if(!s.time_kernel_)
{
launch_and_check(s, std::forward<Callables>(callables)...);
return 0;
}
if(s.is_gpu_timer_) {
gpu_timer timer {};
auto time_launches = [&](auto timer) {
// warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
for(int i = 0; i < s.nrepeat_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
};
if(s.is_gpu_timer_)
{
return time_launches(gpu_timer{});
}
else {
cpu_timer timer {};
// warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
else
{
return time_launches(cpu_timer{});
}
// clang-format on
}
} // namespace ck_tile