diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9d87115676..23eb567d97 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") # to be included in "make all/install/check" message("adding example ${EXAMPLE_FMHA_BWD}") -add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_nocoex_a32.cpp hsaco/bwd_bf16_nocoex_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_nocoex_a32.cpp hsaco/bwd_fp16_nocoex_causal_a32.cpp fmha_bwd.cpp) +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_spec_a32.cpp hsaco/bwd_bf16_spec_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_spec_a32.cpp hsaco/bwd_fp16_spec_causal_a32.cpp fmha_bwd.cpp) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0cfe797042..04a0e7b9cb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -188,7 +188,7 @@ struct p2 unsigned int _p0; unsigned int _p1; }}; -struct __attribute__((packed)) fmha_bwd_asm_args +struct __attribute__((packed)) fmha_bwd_v3_args {{ void* ptr_dq; p2 _p0; @@ -224,7 +224,7 @@ struct __attribute__((packed)) fmha_bwd_asm_args p3 _p15; }}; -struct __attribute__((packed)) fmha_bwd_xqa_asm_args +struct __attribute__((packed)) fmha_bwd_xqa_v3_args {{ void* ptr_dq; p2 _p0; @@ -270,7 +270,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_asm_args p3 _p20; }}; -struct fmha_bwd_ext_traits +struct fmha_bwd_v3_traits {{ int b; int h; @@ -283,17 +283,17 @@ struct fmha_bwd_ext_traits int ts_kv; }}; -class fmha_bwd_ext_kernel +class fmha_bwd_v3_kernel {{ public: - fmha_bwd_ext_kernel(const std::string& name, unsigned char buffer[]) + fmha_bwd_v3_kernel(const std::string& name, unsigned char buffer[]) {{ HIP_CALL(hipModuleLoadData(&module, buffer)); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str())); }} void - launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_asm_args args, const ck_tile::stream_config& s) const + launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const {{ size_t arg_size = sizeof(args); void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, @@ -303,12 +303,12 @@ class fmha_bwd_ext_kernel HIP_LAUNCH_PARAM_END}}; int bdx = 256; - int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv; - int gdy = fmha_ext_traits.h; - int gdz = fmha_ext_traits.b; - if(fmha_ext_traits.mask > 0) + int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv; + int gdy = fmha_v3_traits.h; + int gdz = fmha_v3_traits.b; + if(fmha_v3_traits.mask > 0) {{ - int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv; + int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv; gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2); }} HIP_CALL(hipModuleLaunchKernel(kernel_func, @@ -325,7 +325,7 @@ class fmha_bwd_ext_kernel }} void - launch_kernel(fmha_bwd_ext_traits fmha_ext_traits, fmha_bwd_xqa_asm_args args, const ck_tile::stream_config& s) const + launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_args args, const ck_tile::stream_config& s) const {{ size_t arg_size = sizeof(args); void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, @@ -335,12 +335,12 @@ class fmha_bwd_ext_kernel HIP_LAUNCH_PARAM_END}}; int bdx = 256; - int gdx = fmha_ext_traits.s / fmha_ext_traits.ts_kv; - int gdy = fmha_ext_traits.h; - int gdz = fmha_ext_traits.b; - if(fmha_ext_traits.mask > 0) + int gdx = fmha_v3_traits.s / fmha_v3_traits.ts_kv; + int gdy = fmha_v3_traits.h; + int gdz = fmha_v3_traits.b; + if(fmha_v3_traits.mask > 0) {{ - int num_tg = fmha_ext_traits.s / fmha_ext_traits.ts_kv; + int num_tg = fmha_v3_traits.s / fmha_v3_traits.ts_kv; gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2); }} HIP_CALL(hipModuleLaunchKernel(kernel_func, @@ -374,11 +374,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm) +float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_ext_name << std::flush; - fmha_bwd_asm_args args; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_v3_name << std::flush; + fmha_bwd_v3_args args; args.ptr_dq = a.dq_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -406,15 +406,15 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned args.Hs = stride_head; args.BAs = stride_batch; args.Seqs = stride_seqlen; - auto traits = fmha_bwd_ext_traits{{a.batch, - a.nhead_q, - a.seqlen_q, - a.hdim_q, - 1, - a.mask_type, - 32, - 128}}; - fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm); + auto traits = fmha_bwd_v3_traits{{a.batch, + a.nhead_q, + a.seqlen_q, + a.hdim_q, + 1, + a.mask_type, + 32, + 128}}; + static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -422,11 +422,11 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned }} template -float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm) +float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_ext_name << std::flush; - fmha_bwd_xqa_asm_args args; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_v3_name << std::flush; + fmha_bwd_xqa_v3_args args; args.ptr_dq = a.dq_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -469,15 +469,15 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig args.BAs_kv = stride_batch_kv; args.Seqs_kv = stride_seqlen_kv; args.Seqs_dkv = stride_seqlen_dkv; - auto traits = fmha_bwd_ext_traits{{a.batch, - a.nhead_q, - a.seqlen_q, - a.hdim_q, - 1, - a.mask_type, - 32, - 128}}; - fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm); + auto traits = fmha_bwd_v3_traits{{a.batch, + a.nhead_q, + a.seqlen_q, + a.hdim_q, + 1, + a.mask_type, + 32, + 128}}; + static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }} @@ -485,11 +485,11 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig }} template -float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_ext_asm[], const std::string& bwd_ext_name, bool io_perm) +float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_v3_buf[], const std::string& bwd_v3_name, bool io_perm) {{ if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_ext_name << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; - fmha_bwd_asm_args args; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; + fmha_bwd_v3_args args; args.ptr_dq = a.dq_acc_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -517,15 +517,15 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned args.Hs = stride_head; args.BAs = stride_batch; args.Seqs = stride_seqlen; - auto traits = fmha_bwd_ext_traits{{a.batch, - a.nhead_q, - a.seqlen_q, - a.hdim_q, - 1, - a.mask_type, - 32, - 128}}; - fmha_bwd_ext_kernel impl(HSA_KERNEL, bwd_ext_asm); + auto traits = fmha_bwd_v3_traits{{a.batch, + a.nhead_q, + a.seqlen_q, + a.hdim_q, + 1, + a.mask_type, + 32, + 128}}; + static fmha_bwd_v3_kernel impl(HSA_KERNEL, bwd_v3_buf); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}, @@ -536,139 +536,139 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; - if (t.uses_ext_asm == true){{ + if (t.uses_bwd_v3 == true){{ if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) && (a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{ if(t.data_type.compare("fp16") == 0){{ if(t.mask_type == mask_enum::no_mask){{ - if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) + if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.is_asm_no_coex == true){{ + if(t.is_v3_spec == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_a32"; + const std::string bwd_v3_name = "bwd_v3_fp16_spec_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_fp16_nocoex_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_fp16_spec_a32, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_a32"; + const std::string bwd_v3_name = "bwd_v3_fp16_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_fp16_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_fp16_a32, bwd_v3_name, io_perm); return r; }} }} - else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ + else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_a16"; + const std::string bwd_v3_name = "bwd_v3_fp16_a16"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_fp16_a16, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_fp16_a16, bwd_v3_name, io_perm); return r; }} }} else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) + if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.is_asm_no_coex == true){{ + if(t.is_v3_spec == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_causal_a32"; + const std::string bwd_v3_name = "bwd_v3_fp16_spec_causal_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_fp16_nocoex_causal_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_fp16_spec_causal_a32, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32"; + const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_fp16_causal_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_fp16_causal_a32, bwd_v3_name, io_perm); return r; }} }} - else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ + else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_fp16_causal_a16"; + const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_fp16_causal_a16, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_fp16_causal_a16, bwd_v3_name, io_perm); return r; }} }} }} else if(t.data_type.compare("bf16") == 0){{ if(t.mask_type == mask_enum::no_mask){{ - if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) + if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.is_asm_no_coex == true){{ + if(t.is_v3_spec == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_a32"; + const std::string bwd_v3_name = "bwd_v3_bf16_spec_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_bf16_nocoex_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_bf16_spec_a32, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_a32"; + const std::string bwd_v3_name = "bwd_v3_bf16_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_bf16_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_bf16_a32, bwd_v3_name, io_perm); return r; }} }} - else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ - if(t.is_asm_rtz_cvt == true){{ + else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ + if(t.is_v3_rtz_cvt == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_a16_rtz"; + const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_bf16_a16_rtz, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_bf16_a16_rtz, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_a16"; + const std::string bwd_v3_name = "bwd_v3_bf16_a16"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_bf16_a16, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_bf16_a16, bwd_v3_name, io_perm); return r; }} }} }} else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_asm_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) + if((t.is_v3_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.is_asm_no_coex == true){{ + if(t.is_v3_spec == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_causal_a32"; + const std::string bwd_v3_name = "bwd_v3_bf16_spec_causal_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_bf16_nocoex_causal_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_bf16_spec_causal_a32, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32"; + const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_(s, a, bwd_bf16_causal_a32, bwd_ext_name, io_perm); + r = fmha_bwd_v3_(s, a, bwd_bf16_causal_a32, bwd_v3_name, io_perm); return r; }} }} - else if((t.is_asm_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ - if(t.is_asm_rtz_cvt == true){{ + else if((t.is_v3_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{ + if(t.is_v3_rtz_cvt == true){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16_rtz"; + const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_bf16_causal_a16_rtz, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_bf16_causal_a16_rtz, bwd_v3_name, io_perm); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; - const std::string bwd_ext_name = "bwd_ext_bf16_causal_a16"; + const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16"; bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_ext_bwd_xqa_(s, a, bwd_bf16_causal_a16, bwd_ext_name, io_perm); + r = fmha_bwd_v3_xqa_(s, a, bwd_bf16_causal_a16, bwd_v3_name, io_perm); return r; }} }} diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 183072fc27..16249fe9a0 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -92,17 +92,17 @@ auto create_args(int argc, char* argv[]) "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " "will not be used") - .insert("ext_asm", "0", "if set to 1, some cases will call the ext asm dqdkdv kernel") + .insert("bwd_v3", "0", "if set to 1, some cases will call the bwd v3 dqdkdv kernel") .insert( - "asm_atomic_fp32", + "v3_atomic_fp32", "1", - "if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1") - .insert("asm_no_coex", + "if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when bwd_v3 is set to 1") + .insert("v3_spec", "0", - "if set to 1 will use non-coexectuion kernel when ext_asm is set to 1") - .insert("asm_rtz_cvt", + "if set to 1 will call the specialized v3 kernel when bwd_v3 is set to 1") + .insert("v3_rtz_cvt", "0", - "if set to 1 will use float to bf16 RTZ convert when ext_asm is set to 1"); + "if set to 1 will use float to bf16 RTZ convert when bwd_v3 is set to 1"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -187,14 +187,14 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } - int stream_warmup = arg_parser.get_int("warmup"); - int stream_repeat = arg_parser.get_int("repeat"); - bool kname = arg_parser.get_bool("kname"); - bool deterministic = arg_parser.get_bool("deterministic"); - bool ext_asm = arg_parser.get_bool("ext_asm"); - bool asm_atomic_fp32 = arg_parser.get_bool("asm_atomic_fp32"); - bool asm_no_coex = arg_parser.get_bool("asm_no_coex"); - bool asm_rtz_cvt = arg_parser.get_bool("asm_rtz_cvt"); + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + bool deterministic = arg_parser.get_bool("deterministic"); + bool bwd_v3 = arg_parser.get_bool("bwd_v3"); + bool v3_atomic_fp32 = arg_parser.get_bool("v3_atomic_fp32"); + bool v3_spec = arg_parser.get_bool("v3_spec"); + bool v3_rtz_cvt = arg_parser.get_bool("v3_rtz_cvt"); ck_tile::stream_config stream_config{nullptr, true, @@ -430,10 +430,10 @@ bool run(const ck_tile::ArgParser& arg_parser) p_drop > 0.0f, s_randval, deterministic, - ext_asm, - asm_atomic_fp32, - asm_no_coex, - asm_rtz_cvt}; + bwd_v3, + v3_atomic_fp32, + v3_spec, + v3_rtz_cvt}; auto fmha_args = [&]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 0cadcb7b93..660ba2d47f 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -438,10 +438,10 @@ struct fmha_bwd_traits bool has_dropout; bool is_store_randval; bool is_deterministic; - bool uses_ext_asm; - bool is_asm_atomic_fp32; - bool is_asm_no_coex; - bool is_asm_rtz_cvt; + bool uses_bwd_v3; + bool is_v3_atomic_fp32; + bool is_v3_spec; + bool is_v3_rtz_cvt; // TODO: padding check is inside this api }; float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_a32.cpp b/example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_a32.cpp similarity index 99% rename from example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_a32.cpp rename to example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_a32.cpp index 9864ef6893..ce90017bc1 100644 --- a/example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_a32.cpp +++ b/example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_a32.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_hsaco.hpp" -unsigned char bwd_bf16_nocoex_a32[] = { +unsigned char bwd_bf16_spec_a32[] = { 0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xB0, 0x7D, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_causal_a32.cpp b/example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_causal_a32.cpp similarity index 99% rename from example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_causal_a32.cpp rename to example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_causal_a32.cpp index ba7108583c..57f450d290 100644 --- a/example/ck_tile/01_fmha/hsaco/bwd_bf16_nocoex_causal_a32.cpp +++ b/example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_causal_a32.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_hsaco.hpp" -unsigned char bwd_bf16_nocoex_causal_a32[] = { +unsigned char bwd_bf16_spec_causal_a32[] = { 0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x85, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_a32.cpp b/example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_a32.cpp similarity index 99% rename from example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_a32.cpp rename to example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_a32.cpp index 0b8af82a90..753e64c63e 100644 --- a/example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_a32.cpp +++ b/example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_a32.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_hsaco.hpp" -unsigned char bwd_fp16_nocoex_a32[] = { +unsigned char bwd_fp16_spec_a32[] = { 0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x98, 0x5B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_causal_a32.cpp b/example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_causal_a32.cpp similarity index 99% rename from example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_causal_a32.cpp rename to example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_causal_a32.cpp index 55b1d6b533..36915cb1b7 100644 --- a/example/ck_tile/01_fmha/hsaco/bwd_fp16_nocoex_causal_a32.cpp +++ b/example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_causal_a32.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_hsaco.hpp" -unsigned char bwd_fp16_nocoex_causal_a32[] = { +unsigned char bwd_fp16_spec_causal_a32[] = { 0x7F, 0x45, 0x4C, 0x46, 0x02, 0x01, 0x01, 0x40, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0xE0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x62, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp b/example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp index 9a165169b2..71abf6f117 100644 --- a/example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp +++ b/example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp @@ -9,11 +9,11 @@ extern unsigned char bwd_bf16_a32[]; extern unsigned char bwd_bf16_causal_a16[]; extern unsigned char bwd_bf16_causal_a16_rtz[]; extern unsigned char bwd_bf16_causal_a32[]; -extern unsigned char bwd_bf16_nocoex_a32[]; -extern unsigned char bwd_bf16_nocoex_causal_a32[]; +extern unsigned char bwd_bf16_spec_a32[]; +extern unsigned char bwd_bf16_spec_causal_a32[]; extern unsigned char bwd_fp16_a16[]; extern unsigned char bwd_fp16_a32[]; extern unsigned char bwd_fp16_causal_a16[]; extern unsigned char bwd_fp16_causal_a32[]; -extern unsigned char bwd_fp16_nocoex_a32[]; -extern unsigned char bwd_fp16_nocoex_causal_a32[]; +extern unsigned char bwd_fp16_spec_a32[]; +extern unsigned char bwd_fp16_spec_causal_a32[]; diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh b/example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh index fc5ce8333b..f57db3a788 100644 --- a/example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh +++ b/example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh @@ -9,23 +9,23 @@ for hdim in 128 ; do nhead=$((2048 / $hdim)) # follow fav2 setup $EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 $EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 $EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 $EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 $EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -ext_asm=1 -asm_atomic_fp32=0 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -bwd_v3=1 -v3_atomic_fp32=0 -v=$VALID ; sleep 3 done done diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh index 2e380993fd..3dfa19ef2c 100644 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh @@ -11,12 +11,12 @@ set -x for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 128 ; do -for asm_atomic_fp32 in 0 1 ; do -for asm_no_coex in 0 1 ; do +for v3_atomic_fp32 in 0 1 ; do +for v3_spec in 0 1 ; do for mask in 0 1 ; do -$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -mode=0 -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -v3_spec=$v3_spec -mode=0 -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh index b753d60fc5..0d616864b6 100644 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh @@ -13,8 +13,8 @@ for perm in 0 1 ; do for hdim in 128 ; do for mask in 0 1 ; do -$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=0 -asm_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=4 -h_k=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=3 -h_k=1 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=0 -v3_rtz_cvt=1 -mode=0 -kname=$KNAME $COMMON_ARGS done done