mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
refactor ck-tile kernel launch (#1925)
[ROCm/composable_kernel commit: 9e132eb77c]
This commit is contained in:
@@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
|
||||
);
|
||||
}}
|
||||
|
||||
|
||||
@@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
|
||||
);
|
||||
}}
|
||||
|
||||
|
||||
@@ -38,9 +38,20 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
|
||||
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
return hipPeekAtLastError() == hipSuccess;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables)
|
||||
{
|
||||
// abort the sequence in case of intermediate error
|
||||
if(!(callables(sc) && ...))
|
||||
{
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
/*
|
||||
* launch_kernel()
|
||||
@@ -69,38 +80,39 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt
|
||||
**/
|
||||
// clang-format on
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
// clang-format off
|
||||
if(!s.time_kernel_) {
|
||||
(callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
if(s.is_gpu_timer_) {
|
||||
gpu_timer timer {};
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
};
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
}
|
||||
else {
|
||||
cpu_timer timer {};
|
||||
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user