diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6a1c0e938d..11bcc0f570 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -590,7 +590,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static thread_local: each thread gets its own impl instance, for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -634,7 +634,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static thread_local: each thread gets its own impl instance, for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -677,7 +677,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static thread_local: each thread gets its own impl instance, for thread safety. 
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}, @@ -722,7 +722,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static thread_local: each thread gets its own impl instance, for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},