enable fa v3 bwd codegen

This commit is contained in:
Jim
2025-02-27 19:32:10 +08:00
parent 32760acdbe
commit c4dfe25bcc
2 changed files with 215 additions and 609 deletions

View File

@@ -126,3 +126,42 @@ BOOL_MAP = {
"t" : "true",
"f" : "false"
}
BWD_V3_HDIM_MAP = {
"64": "64",
"128": "128"
}
BF16_CVT_MAP = {
0 : "rtne",
1 : "rtna",
2 : "rtz",
}
BWD_V3_MASK_MAP = {
"t": "((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0)))",
"f": "(t.mask_type == mask_enum::no_mask)"
}
BWD_V3_ATOMIC32_MAP = {
"t": "((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/))",
"f": "(t.is_v3_atomic_fp32 == false)"
}
BWD_V3_HDIM_CASE_MAP = {
0: "(a.hdim_q == 128)",
1: "(a.hdim_q == 64)",
2: "((a.hdim_q > 64) && (a.hdim_q < 128))"
}
BWD_V3_HDIM_CASE_CHECK_MAP = {
0: 128,
1: 64,
2: 128
}
BWD_V3_PADDING_CHECK_MAP = {
0: "false",
1: "false",
2: "true"
}

View File

@@ -161,6 +161,12 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
"""
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
FMHA_BWD_V3_TEMPLATE="""template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic}, {F_bf16_cvt}, {F_hdpad}>> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }};
"""
FMHA_BWD_API="""
#include <iostream>
#include "hsaco/fmha_hsaco.hpp"
@@ -312,157 +318,9 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
}};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32_pddv; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
{F_template}
class fmha_bwd_v3_kernel
{{
@@ -738,445 +596,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
(a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
(a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))) {{
if((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(t.how_v3_bf16_cvt == 0){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(t.how_v3_bf16_cvt == 0){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.hdim_q == 128){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>;
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv";
r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
}}
}}
}}
else if((a.hdim_q == 64) && (a.seqlen_k % 64 == 0)){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.is_v3_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.is_v3_atomic_fp32 == false){{
if(t.how_v3_bf16_cvt == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>;
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 1){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
else if(t.how_v3_bf16_cvt == 2){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;
}}
}}
}}
}}
if (((a.hdim_q >= 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0))) {{
{F_v3_dispatch}
}}
}}
}}
{F_dispatch}
return r;
}}
@@ -1201,26 +625,62 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
}}
"""
FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>;
r = fmha_bwd_v3{F_padding_suffix}_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;"""
FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>;
r = fmha_bwd_v3{F_padding_suffix}_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
return r;"""
FMHA_BWD_V3_PER_DTYPE_CASE=""" {F_if} (t.data_type.compare(\"{F_dtype}\") == 0) {{
{per_bf16_cvt_dispatch}
}}
"""
FMHA_BWD_V3_PER_BF16_CVT_CASE=""" {F_if} (t.how_v3_bf16_cvt == {F_bf16_cvt}) {{
{per_mask_dispatch}
}}
"""
FMHA_BWD_V3_PER_MASK_CASE=""" {F_if} {F_mask_expression}{{
{per_atomic_dispatch}
}}
"""
FMHA_BWD_V3_PER_ATOMIC_CASE=""" {F_if} {F_atomic_expression}{{
{per_hdim_dispatch}
}}
"""
FMHA_BWD_V3_PER_HDIM_CASE=""" {F_if} {F_hdim_expression}{{
{inner_dispatch}
}}
"""
@dataclass
class FmhaBwdDQDKDVApiTrait:
pipeline : str
pipeline : str
# sync with fmha_bwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along k seqlen
bhdq : int # q head_dim
bhdv : int # v head_dim
mask : str
bias : str
dbias : str
dropout : str
spad : str
skpad : str
dpad : str
dvpad : str
deterministic : str
def scheck(self, spad1 : str) -> str:
if self.mode == 'group':
@@ -1251,9 +711,25 @@ class FmhaBwdDQDKDVApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
@dataclass
class FmhaBwdV3DQDKDVApiTrait:
hdim : str
dtype : str # data type
is_causal : str
is_atomic : str
bf16_cvt : int
is_hdpad : str
def remap_hdim(self):
hdim_int = int(self.hdim)
if hdim_int > 64:
self.hdim = 128
hdim_int = (hdim_int + 64 - 1) / 64 * 64
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = dict()
self.dq_dk_dv_v3_pool = dict()
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
@@ -1265,6 +741,15 @@ class FmhaBwdApiPool:
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
def register_dq_dk_dv_v3_traits(self, trait : FmhaBwdV3DQDKDVApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.dq_dk_dv_v3_pool.keys():
self.dq_dk_dv_v3_pool[trait.dtype] = dict()
if trait.hdim not in self.dq_dk_dv_v3_pool[trait.dtype].keys():
self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim] = list()
self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
@@ -1285,7 +770,6 @@ class FmhaBwdApiPool:
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -1293,7 +777,57 @@ class FmhaBwdApiPool:
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
gen_template = str()
for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
for j, hdim in enumerate(BWD_V3_HDIM_MAP.keys()):
traits = self.dq_dk_dv_v3_pool[dtype][hdim]
hdim = int(hdim)
Ts_qo = 32 if hdim == 64 else 16
for k, trait in enumerate(traits):
if hdim == 64 and trait.is_hdpad == "t":
continue
hdim_name = "_hd64" if hdim == 64 else ""
dtype_name = "_{}".format(dtype)
causal_name = "_causal" if trait.is_causal == "t" else ""
atomic_name = "_a32" if trait.is_atomic == "t" else "_a16"
bf16_cvt_name = "_{}".format(BF16_CVT_MAP[trait.bf16_cvt])
bf16_cvt_name = bf16_cvt_name if dtype == "bf16" else ""
hdpad_name = "_pddv" if trait.is_hdpad == "t" else ""
gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic],
F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[trait.is_hdpad], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name,
F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name, F_hdpad_name=hdpad_name)
v3_code = str()
for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()):
per_bf16_cvt = str()
for j, bf16_cvt in enumerate([0, 1, 2]):
per_mask = str()
for k, is_causal in enumerate(["t", "f"]):
per_atomic = str()
for l, is_atomic in enumerate(["t", "f"]):
per_hdim = str()
for m, hdim in enumerate(BWD_V3_HDIM_CASE_CHECK_MAP.values()):
if_m = 'if' if m == 0 else 'else if'
inners = str()
bf16_cvt_tmp = 0 if dtype == "fp16" else bf16_cvt
padding_suffix = "_gen" if BWD_V3_PADDING_CHECK_MAP[m] == "true" else ""
if is_atomic == "t":
inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix)
else:
inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix)
per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners)
if_l = 'if' if l == 0 else 'else if'
per_atomic = per_atomic + FMHA_BWD_V3_PER_ATOMIC_CASE.format(F_if=if_l, F_atomic_expression=BWD_V3_ATOMIC32_MAP[is_atomic], per_hdim_dispatch=per_hdim)
if_k = 'if' if k == 0 else 'else if'
per_mask = per_mask + FMHA_BWD_V3_PER_MASK_CASE.format(F_if=if_k, F_mask_expression=BWD_V3_MASK_MAP[is_causal], per_atomic_dispatch=per_atomic)
if_j = 'if' if j == 0 else 'else if'
per_bf16_cvt = per_bf16_cvt + FMHA_BWD_V3_PER_BF16_CVT_CASE.format(F_if=if_j, F_bf16_cvt=bf16_cvt, per_mask_dispatch=per_mask)
if_i = 'if' if i == 0 else 'else if'
v3_code = v3_code + FMHA_BWD_V3_PER_DTYPE_CASE.format(F_if=if_i, F_dtype=dtype, per_bf16_cvt_dispatch=per_bf16_cvt)
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code)
# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
@@ -1444,20 +978,39 @@ class FmhaBwdDQDKDVKernel:
dvpad=self.F_dvpad,
deterministic=self.F_deterministic
)
@dataclass
class FmhaBwdV3DQDKDVKernel:
F_hdim : int # hdim
F_dtype : str # data type
F_is_causal : str
F_is_atomic : str
F_bf16_cvt : int
F_is_hdpad : str
def v3_api_trait(self) -> FmhaBwdV3DQDKDVApiTrait:
return FmhaBwdV3DQDKDVApiTrait(hdim=str(self.F_hdim),
dtype=self.F_dtype,
is_causal=self.F_is_causal,
is_atomic=self.F_is_atomic,
bf16_cvt=self.F_bf16_cvt,
is_hdpad=self.F_is_hdpad
)
# TODO: design a more practical way to do it
# this is current supported tile size & pipeline.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"]
# '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr_iglp", "kr_ktr_vr"]
}
else:
return None
@@ -1501,8 +1054,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
cond &= bias in ['no']
cond &= dropout in ['no']
cond &= dpad == dvpad
cond &= spad == skpad
cond &= spad == "f"
cond &= deterministic == "f"
if not cond:
continue
@@ -1538,6 +1094,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool.register_dq_dk_dv_traits(k.api_trait())
gen.append(k)
for hdim_str, is_causal, is_atomic, bf16_cvt, is_hdpad in itertools.product(d.keys(), ["t", "f"], ["t", "f"], [0, 1, 2], ["t", "f"]):
hdim = int(hdim_str)
k = FmhaBwdV3DQDKDVKernel(F_hdim=hdim, F_dtype=dtype, F_is_causal=is_causal, F_is_atomic=is_atomic, F_bf16_cvt=bf16_cvt, F_is_hdpad=is_hdpad)
if receipt == 3:
cond = (dtype == 'fp16') and (bf16_cvt in [1, 2])
if cond:
# print(dtype, bf16_cvt)
continue
api_pool.register_dq_dk_dv_v3_traits(k.v3_api_trait())
# gen.append(k)
return (api_pool, gen)
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""