mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Merge pull request #1897 from ROCm/yewang12/ck_fav3_thread_local
make fmha_bwd_v3_kernel thread_local
This commit is contained in:
@@ -590,7 +590,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
|
||||
@@ -634,7 +634,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
|
||||
@@ -677,7 +677,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
|
||||
@@ -722,7 +722,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
|
||||
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
|
||||
|
||||
Reference in New Issue
Block a user