Merge pull request #1897 from ROCm/yewang12/ck_fav3_thread_local

Make the function-local static fmha_bwd_v3_kernel instance thread_local
This commit is contained in:
Dan Yao
2025-02-22 12:00:42 +08:00
committed by GitHub

View File

@@ -590,7 +590,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
a.mask_type,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static thread_local: each thread gets its own impl, constructed on that thread's first pass through this declaration, so no kernel object is shared across threads.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
@@ -634,7 +634,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
a.mask_type,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static thread_local: each thread gets its own impl, constructed on that thread's first pass through this declaration, so no kernel object is shared across threads.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
@@ -677,7 +677,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
a.mask_type,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static thread_local: each thread gets its own impl, constructed on that thread's first pass through this declaration, so no kernel object is shared across threads.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
@@ -722,7 +722,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
a.mask_type,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static thread_local: each thread gets its own impl, constructed on that thread's first pass through this declaration, so no kernel object is shared across threads.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},