From 216a6c2518fe23e45157dc6e0be6f44154b6f0d0 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 18 Feb 2025 10:45:09 -0600 Subject: [PATCH] make fmha_bwd_v3_kernel thread_local In Jax TE, multiple threads in the same process are spawned to train on each GPU. Therefore hipModuleLoadData and hipModuleGetFunction need to be run for each GPU in its corresponding thread. --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 36c49d6ad8..77eccb374d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -590,7 +590,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -634,7 +634,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. 
return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -677,7 +677,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}, @@ -722,7 +722,7 @@ float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) a.mask_type, FmhaBwdV3Ts::ts_qo, FmhaBwdV3Ts::ts_kv}}; - static fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. + static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf::bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}, @@ -1800,4 +1800,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")