mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add api limit for top-left causal mask
This commit is contained in:
@@ -968,37 +968,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
|
||||
@@ -1223,37 +1225,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
@@ -1267,37 +1271,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
@@ -1311,37 +1317,39 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
@@ -1433,21 +1441,23 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
}}
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
|
||||
@@ -1547,58 +1557,60 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
}}
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
Reference in New Issue
Block a user