diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index ed70446cad..36c49d6ad8 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -189,7 +189,7 @@ struct p2 unsigned int _p1; }}; -struct __attribute__((packed)) fmha_bwd_xqa_v3_args +struct __attribute__((packed)) fmha_bwd_v3_args {{ void* ptr_dq; p2 _p0; @@ -235,7 +235,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_v3_args p3 _p20; }}; -struct __attribute__((packed)) fmha_bwd_xqa_v3_dp_args +struct __attribute__((packed)) fmha_bwd_v3_gen_args {{ void* ptr_dq; p2 _p0; @@ -474,7 +474,7 @@ class fmha_bwd_v3_kernel }} void - launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_args args, const ck_tile::stream_config& s) const + launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const {{ size_t arg_size = sizeof(args); void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, @@ -506,7 +506,7 @@ class fmha_bwd_v3_kernel }} void - launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_xqa_v3_dp_args args, const ck_tile::stream_config& s) const + launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_gen_args args, const ck_tile::stream_config& s) const {{ size_t arg_size = sizeof(args); void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, @@ -555,11 +555,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << std::flush; - fmha_bwd_xqa_v3_args args; + fmha_bwd_v3_args args; args.ptr_dq = a.dq_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -598,11 +598,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << std::flush; - fmha_bwd_xqa_v3_dp_args args; + fmha_bwd_v3_gen_args args; args.ptr_dq = a.dq_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -642,11 +642,11 @@ float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; - fmha_bwd_xqa_v3_args args; + fmha_bwd_v3_args args; args.ptr_dq = a.dq_acc_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -686,11 +686,11 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_bwd_v3_hdp_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) +float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; - fmha_bwd_xqa_v3_dp_args args; + fmha_bwd_v3_gen_args args; args.ptr_dq = a.dq_acc_ptr; args.ptr_dk = a.dk_ptr; args.ptr_dv = a.dv_ptr; @@ -747,7 +747,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_a32"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -755,7 +755,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -764,14 +764,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_a16"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>; // const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -783,7 +783,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -791,7 +791,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -800,14 +800,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -822,7 +822,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -830,7 +830,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -840,7 +840,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -848,7 +848,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -858,7 +858,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -866,7 +866,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -877,14 +877,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -893,14 +893,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -909,14 +909,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -930,7 +930,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -938,7 +938,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -948,7 +948,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -956,7 +956,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -966,7 +966,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ @@ -974,7 +974,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -985,14 +985,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -1001,14 +1001,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -1017,14 +1017,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else{{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; - r = fmha_bwd_v3_hdp_xqa_(s, a); + r = fmha_bwd_v3_gen_(s, a); return r; }} }} @@ -1040,14 +1040,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.is_v3_atomic_fp32 == false){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }} @@ -1057,14 +1057,14 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.is_v3_atomic_fp32 == false){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }} @@ -1077,7 +1077,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -1085,7 +1085,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -1093,7 +1093,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }} @@ -1102,21 +1102,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }} @@ -1128,7 +1128,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -1136,7 +1136,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -1144,7 +1144,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }} @@ -1153,21 +1153,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; - r = fmha_bwd_v3_xqa_(s, a); + r = fmha_bwd_v3_(s, a); return r; }} }}