diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 3e6b744712..9c3a144312 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -968,37 +968,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& r = fmha_bwd_v3_(s, a); return r; }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; + // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>; + // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && @@ -1223,37 +1225,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& r = fmha_bwd_v3_(s, a); return r; }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} else if(t.how_v3_bf16_cvt == 1){{ @@ -1267,37 +1271,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& r = fmha_bwd_v3_(s, a); return r; }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} else if(t.how_v3_bf16_cvt == 2){{ @@ -1311,37 +1317,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& r = fmha_bwd_v3_(s, a); return r; }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>; + // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} }} @@ -1433,21 +1441,23 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& }} else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(a.seqlen_q % 64 == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if(a.seqlen_q % 64 == 0){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else{{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && @@ -1547,58 +1557,60 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& }} else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.how_v3_bf16_cvt == 0){{ - if(a.seqlen_q % 64 == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{ + if(t.how_v3_bf16_cvt == 0){{ + if(a.seqlen_q % 64 == 0){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else{{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; + else if(t.how_v3_bf16_cvt == 1){{ + if(a.seqlen_q % 64 == 0){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else{{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} - }} - else if(t.how_v3_bf16_cvt == 1){{ - if(a.seqlen_q % 64 == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 2){{ - if(a.seqlen_q % 64 == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; - r = fmha_bwd_v3_genl_(s, a); - return r; + else if(t.how_v3_bf16_cvt == 2){{ + if(a.seqlen_q % 64 == 0){{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} + else{{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>; + // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; + r = fmha_bwd_v3_genl_(s, a); + return r; + }} }} }} }}