diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..d5556bcb83 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -126,3 +126,42 @@ BOOL_MAP = { "t" : "true", "f" : "false" } + +BWD_V3_HDIM_MAP = { + "64": "64", + "128": "128" +} + +BF16_CVT_MAP = { + 0 : "rtne", + 1 : "rtna", + 2 : "rtz", +} + +BWD_V3_MASK_MAP = { + "t": "((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0)))", + "f": "(t.mask_type == mask_enum::no_mask)" +} + +BWD_V3_ATOMIC32_MAP = { + "t": "((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/))", + "f": "(t.is_v3_atomic_fp32 == false)" +} + +BWD_V3_HDIM_CASE_MAP = { + 0: "(a.hdim_q == 128)", + 1: "(a.hdim_q == 64)", + 2: "((a.hdim_q > 64) && (a.hdim_q < 128))" +} + +BWD_V3_HDIM_CASE_CHECK_MAP = { + 0: 128, + 1: 64, + 2: 128 +} + +BWD_V3_PADDING_CHECK_MAP = { + 0: "false", + 1: "false", + 2: "true" +} \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 11bcc0f570..047cf67f55 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -161,6 +161,12 @@ std::string fmha_bwd_dq_dk_dv_get_name_() """ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" + +FMHA_BWD_V3_TEMPLATE="""template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}"; }}; +template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd{F_hdim_name}{F_dtype_name}{F_causal_name}{F_atomic_name}{F_bf16_cvt_name}{F_hdpad_name}; }}; +template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = {F_Ts_qo}; static constexpr int ts_kv = 192; }}; +""" + FMHA_BWD_API=""" #include #include "hsaco/fmha_hsaco.hpp" @@ -312,157 +318,9 @@ struct fmha_bwd_dq_dk_dv_v3_traits_ }}; template struct FmhaBwdV3Name; -// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad| -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a16_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_a32_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a16"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }}; -template<> struct FmhaBwdV3Name> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; }}; - template struct FmhaBwdV3Buf; -// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad| -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtne_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtna_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a16_rtz_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtne_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtna_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_a32_rtz_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtne_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtna_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a16_rtz_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtne_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtna_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_bf16_causal_a32_rtz_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a16_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_a32_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a16_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_fp16_causal_a32_pddv; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a16; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }}; -template<> struct FmhaBwdV3Buf> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32; }}; - template struct FmhaBwdV3Ts; -// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsHDPad| -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; -template<> struct FmhaBwdV3Ts> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }}; +{F_template} class fmha_bwd_v3_kernel {{ @@ -738,445 +596,11 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& (a.seqlen_q == a.seqlen_k) && (a.nhead_q % a.nhead_k == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))) {{ - if((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0)){{ - if(t.data_type.compare("fp16") == 0){{ - if(t.mask_type == mask_enum::no_mask){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_a32"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_a32_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_a16"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true>; - // const std::string bwd_v3_name = "bwd_v3_fp16_a16_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true>; - // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - }} - else if(t.data_type.compare("bf16") == 0){{ - if(t.mask_type == mask_enum::no_mask){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.how_v3_bf16_cvt == 0){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 1){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 2){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(t.how_v3_bf16_cvt == 0){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 1){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 2){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - }} - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.how_v3_bf16_cvt == 0){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 1){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 2){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(t.how_v3_bf16_cvt == 0){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 1){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - else if(t.how_v3_bf16_cvt == 2){{ - if(a.hdim_q == 128){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else{{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, true>; - // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz_pddv"; - r = fmha_bwd_v3_gen_(s, a); - return r; - }} - }} - }} - }} - }} - }} - else if((a.hdim_q == 64) && (a.seqlen_k % 64 == 0)){{ - if(t.data_type.compare("fp16") == 0){{ - if(t.mask_type == mask_enum::no_mask){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.is_v3_atomic_fp32 == false){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.is_v3_atomic_fp32 == false){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - }} - else if(t.data_type.compare("bf16") == 0){{ - if(t.mask_type == mask_enum::no_mask){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.how_v3_bf16_cvt == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 1){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 2){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(t.how_v3_bf16_cvt == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 1){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 2){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - }} - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{ - if(t.how_v3_bf16_cvt == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 1){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 2){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - else if(t.is_v3_atomic_fp32 == false){{ - if(t.how_v3_bf16_cvt == 0){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false>; - const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 1){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; - r = fmha_bwd_v3_(s, a); - return r; - }} - else if(t.how_v3_bf16_cvt == 2){{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false>; - // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; - r = fmha_bwd_v3_(s, a); - return r; - }} - }} - }} - }} + if (((a.hdim_q >= 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.seqlen_k % 64 == 0))) {{ +{F_v3_dispatch} }} }} }} - {F_dispatch} return r; }} @@ -1201,26 +625,62 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) }} """ +FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}, false>; + r = fmha_bwd_v3{F_padding_suffix}_(s, a); + return r;""" + +FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH=""" using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, false, false, {F_padding}>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<{F_hdim}, {F_dtype}, {F_is_causal}, {F_is_atomic32}, {F_how_v3_bf16_cvt}, {F_padding}>; + r = fmha_bwd_v3{F_padding_suffix}_(s, a); + return r;""" + +FMHA_BWD_V3_PER_DTYPE_CASE=""" {F_if} (t.data_type.compare(\"{F_dtype}\") == 0) {{ +{per_bf16_cvt_dispatch} + }} +""" + +FMHA_BWD_V3_PER_BF16_CVT_CASE=""" {F_if} (t.how_v3_bf16_cvt == {F_bf16_cvt}) {{ +{per_mask_dispatch} + }} +""" + +FMHA_BWD_V3_PER_MASK_CASE=""" {F_if} {F_mask_expression}{{ +{per_atomic_dispatch} + }} +""" + +FMHA_BWD_V3_PER_ATOMIC_CASE=""" {F_if} {F_atomic_expression}{{ +{per_hdim_dispatch} + }} +""" + +FMHA_BWD_V3_PER_HDIM_CASE=""" {F_if} {F_hdim_expression}{{ +{inner_dispatch} + }} +""" + @dataclass class FmhaBwdDQDKDVApiTrait: - pipeline : str + pipeline : str # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - deterministic : str + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + deterministic : str def scheck(self, spad1 : str) -> str: if self.mode == 'group': @@ -1251,9 +711,25 @@ class FmhaBwdDQDKDVApiTrait: if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' else : return f'a.hdim_v % {self.bhdv} == 0' +@dataclass +class FmhaBwdV3DQDKDVApiTrait: + hdim : str + dtype : str # data type + is_causal : str + is_atomic : str + bf16_cvt : int + is_hdpad : str + + def remap_hdim(self): + hdim_int = int(self.hdim) + if hdim_int > 64: + self.hdim = 128 + hdim_int = (hdim_int + 64 - 1) / 64 * 64 + class FmhaBwdApiPool: def __init__(self, mask_impl): self.dq_dk_dv_pool = dict() + self.dq_dk_dv_v3_pool = dict() self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: @@ -1265,6 +741,15 @@ class FmhaBwdApiPool: self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + def register_dq_dk_dv_v3_traits(self, trait : FmhaBwdV3DQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_v3_pool.keys(): + self.dq_dk_dv_v3_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_v3_pool[trait.dtype].keys(): + self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_v3_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + @property def api(self) -> str: per_dtypes=str() @@ -1285,7 +770,6 @@ class FmhaBwdApiPool: F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_deterministic=BOOL_MAP[trait.deterministic]) - if_j = 'if' if j == 0 else 'else if' per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' @@ -1293,7 +777,57 @@ class FmhaBwdApiPool: if not per_dtypes: # empty string we add some ignore to suppress warning in api per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + + gen_template = str() + for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()): + for j, hdim in enumerate(BWD_V3_HDIM_MAP.keys()): + traits = self.dq_dk_dv_v3_pool[dtype][hdim] + hdim = int(hdim) + Ts_qo = 32 if hdim == 64 else 16 + for k, trait in enumerate(traits): + if hdim == 64 and trait.is_hdpad == "t": + continue + hdim_name = "_hd64" if hdim == 64 else "" + dtype_name = "_{}".format(dtype) + causal_name = "_causal" if trait.is_causal == "t" else "" + atomic_name = "_a32" if trait.is_atomic == "t" else "_a16" + bf16_cvt_name = "_{}".format(BF16_CVT_MAP[trait.bf16_cvt]) + bf16_cvt_name = bf16_cvt_name if dtype == "bf16" else "" + hdpad_name = "_pddv" if trait.is_hdpad == "t" else "" + gen_template = gen_template + FMHA_BWD_V3_TEMPLATE.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_atomic=BOOL_MAP[trait.is_atomic], + F_is_causal=BOOL_MAP[trait.is_causal], F_bf16_cvt=trait.bf16_cvt, F_hdpad=BOOL_MAP[trait.is_hdpad], F_Ts_qo = Ts_qo, F_hdim_name=hdim_name, + F_dtype_name=dtype_name, F_causal_name=causal_name, F_atomic_name=atomic_name, F_bf16_cvt_name=bf16_cvt_name, F_hdpad_name=hdpad_name) + + v3_code = str() + for i, dtype in enumerate(self.dq_dk_dv_v3_pool.keys()): + per_bf16_cvt = str() + for j, bf16_cvt in enumerate([0, 1, 2]): + per_mask = str() + for k, is_causal in enumerate(["t", "f"]): + per_atomic = str() + for l, is_atomic in enumerate(["t", "f"]): + per_hdim = str() + for m, hdim in enumerate(BWD_V3_HDIM_CASE_CHECK_MAP.values()): + if_m = 'if' if m == 0 else 'else if' + inners = str() + bf16_cvt_tmp = 0 if dtype == "fp16" else bf16_cvt + padding_suffix = "_gen" if BWD_V3_PADDING_CHECK_MAP[m] == "true" else "" + if is_atomic == "t": + inners = FMHA_BWD_V3_ATOMIC32_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix) + else: + inners = FMHA_BWD_V3_ATOMIC16_INNER_DISPATCH.format(F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_is_causal=BOOL_MAP[is_causal], F_is_atomic32=BOOL_MAP[is_atomic], F_how_v3_bf16_cvt=bf16_cvt_tmp, F_padding=BWD_V3_PADDING_CHECK_MAP[m], F_padding_suffix=padding_suffix) + per_hdim = per_hdim + FMHA_BWD_V3_PER_HDIM_CASE.format(F_if=if_m, F_hdim_expression=BWD_V3_HDIM_CASE_MAP[m], inner_dispatch=inners) + + if_l = 'if' if l == 0 else 'else if' + per_atomic = per_atomic + FMHA_BWD_V3_PER_ATOMIC_CASE.format(F_if=if_l, F_atomic_expression=BWD_V3_ATOMIC32_MAP[is_atomic], per_hdim_dispatch=per_hdim) + if_k = 'if' if k == 0 else 'else if' + per_mask = per_mask + FMHA_BWD_V3_PER_MASK_CASE.format(F_if=if_k, F_mask_expression=BWD_V3_MASK_MAP[is_causal], per_atomic_dispatch=per_atomic) + if_j = 'if' if j == 0 else 'else if' + per_bf16_cvt = per_bf16_cvt + FMHA_BWD_V3_PER_BF16_CVT_CASE.format(F_if=if_j, F_bf16_cvt=bf16_cvt, per_mask_dispatch=per_mask) + if_i = 'if' if i == 0 else 'else if' + v3_code = v3_code + FMHA_BWD_V3_PER_DTYPE_CASE.format(F_if=if_i, F_dtype=dtype, per_bf16_cvt_dispatch=per_bf16_cvt) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes, F_template = gen_template, F_v3_dispatch = v3_code) # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) @@ -1444,20 +978,39 @@ class FmhaBwdDQDKDVKernel: dvpad=self.F_dvpad, deterministic=self.F_deterministic ) + + +@dataclass +class FmhaBwdV3DQDKDVKernel: + F_hdim : int # hdim + F_dtype : str # data type + F_is_causal : str + F_is_atomic : str + F_bf16_cvt : int + F_is_hdpad : str + + def v3_api_trait(self) -> FmhaBwdV3DQDKDVApiTrait: + return FmhaBwdV3DQDKDVApiTrait(hdim=str(self.F_hdim), + dtype=self.F_dtype, + is_causal=self.F_is_causal, + is_atomic=self.F_is_atomic, + bf16_cvt=self.F_bf16_cvt, + is_hdpad=self.F_is_hdpad + ) # TODO: design a more practical way to do it # this is current supported tile size & pipeline. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], - '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"] + # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"] } else: return None @@ -1501,8 +1054,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue elif receipt == 3: cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] + cond &= bias in ['no'] + cond &= dropout in ['no'] cond &= dpad == dvpad + cond &= spad == skpad + cond &= spad == "f" cond &= deterministic == "f" if not cond: continue @@ -1538,6 +1094,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) + for hdim_str, is_causal, is_atomic, bf16_cvt, is_hdpad in itertools.product(d.keys(), ["t", "f"], ["t", "f"], [0, 1, 2], ["t", "f"]): + hdim = int(hdim_str) + k = FmhaBwdV3DQDKDVKernel(F_hdim=hdim, F_dtype=dtype, F_is_causal=is_causal, F_is_atomic=is_atomic, F_bf16_cvt=bf16_cvt, F_is_hdpad=is_hdpad) + if receipt == 3: + cond = (dtype == 'fp16') and (bf16_cvt in [1, 2]) + if cond: + # print(dtype, bf16_cvt) + continue + api_pool.register_dq_dk_dv_v3_traits(k.v3_api_trait()) + # gen.append(k) + return (api_pool, gen) FMHA_BWD_DOT_DO_O_KERNEL_BODY="""