mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
@@ -401,6 +401,73 @@ struct __attribute__((packed)) fmha_bwd_v3_group_args
|
||||
p1 _p17;
|
||||
}};
|
||||
|
||||
// Argument block `fmha_bwd_v3_swa_genl_args`, declared with __attribute__((packed)).
// Layout as written: nine pointer members first, then 4-byte scalar members,
// each immediately followed by a filler member of type `p1` (_p0 .. _p26).
// NOTE(review): `p1` is declared outside this view; presumably it pads every
// scalar to a fixed slot size -- confirm before adding, removing or reordering
// any member, since packed layouts are position-sensitive.
// NOTE(review): the doubled braces suggest this text lives inside a Python
// format template -- confirm, and keep single braces out of any comment here.
// NOTE(review): "swa" presumably means sliding-window attention -- confirm.
struct __attribute__((packed)) fmha_bwd_v3_swa_genl_args
|
||||
{{
|
||||
// Writable buffers (non-const pointers).
void* ptr_dq;
|
||||
void* ptr_dk;
|
||||
void* ptr_dv;
|
||||
// Read-only buffers (const pointers).
const void* ptr_q;
|
||||
const void* ptr_k;
|
||||
const void* ptr_v;
|
||||
const void* ptr_do;
|
||||
const void* ptr_lse;
|
||||
const void* ptr_d;
|
||||
// Scalar parameters; every one is followed by its own `p1` filler.
float scalar;
|
||||
p1 _p0;
|
||||
float log2e;
|
||||
p1 _p1;
|
||||
unsigned int ratio;
|
||||
p1 _p2;
|
||||
unsigned int seqlen_q;
|
||||
p1 _p3;
|
||||
unsigned int seqlen_k;
|
||||
p1 _p4;
|
||||
unsigned int head_dim;
|
||||
p1 _p5;
|
||||
unsigned int nhead_q;
|
||||
p1 _p6;
|
||||
// Hs_* / BAs_* / Seqs_* triples for q, k, v, do, dk, dv (no dq triple is declared).
// NOTE(review): presumably head / batch / sequence strides -- confirm units with the kernel.
unsigned int Hs_q;
|
||||
p1 _p7;
|
||||
unsigned int BAs_q;
|
||||
p1 _p8;
|
||||
unsigned int Seqs_q;
|
||||
p1 _p9;
|
||||
unsigned int Hs_k;
|
||||
p1 _p10;
|
||||
unsigned int BAs_k;
|
||||
p1 _p11;
|
||||
unsigned int Seqs_k;
|
||||
p1 _p12;
|
||||
unsigned int Hs_v;
|
||||
p1 _p13;
|
||||
unsigned int BAs_v;
|
||||
p1 _p14;
|
||||
unsigned int Seqs_v;
|
||||
p1 _p15;
|
||||
unsigned int Hs_do;
|
||||
p1 _p16;
|
||||
unsigned int BAs_do;
|
||||
p1 _p17;
|
||||
unsigned int Seqs_do;
|
||||
p1 _p18;
|
||||
unsigned int Hs_dk;
|
||||
p1 _p19;
|
||||
unsigned int BAs_dk;
|
||||
p1 _p20;
|
||||
unsigned int Seqs_dk;
|
||||
p1 _p21;
|
||||
unsigned int Hs_dv;
|
||||
p1 _p22;
|
||||
unsigned int BAs_dv;
|
||||
p1 _p23;
|
||||
unsigned int Seqs_dv;
|
||||
p1 _p24;
|
||||
// Signed mask parameters. NOTE(review): exact meaning of x / y is not visible here -- confirm.
int mask_x;
|
||||
p1 _p25;
|
||||
int mask_y;
|
||||
p1 _p26;
|
||||
}};
|
||||
|
||||
struct fmha_bwd_v3_traits
|
||||
{{
|
||||
int b;
|
||||
@@ -415,7 +482,7 @@ struct fmha_bwd_v3_traits
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsCausal_,
|
||||
int mask_type_,
|
||||
bool kIsAtomic32_,
|
||||
ck_tile::index_t BF16Cvt_,
|
||||
bool kIsSEQPad_,
|
||||
@@ -425,7 +492,7 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
|
||||
{{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsCausal = kIsCausal_;
|
||||
static constexpr int mask_type = mask_type_;
|
||||
static constexpr bool kIsAtomic32 = kIsAtomic32_;
|
||||
static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_;
|
||||
static constexpr bool kIsSEQPad = kIsSEQPad_;
|
||||
@@ -434,256 +501,268 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
|
||||
}};
|
||||
|
||||
// FmhaBwdV3Name: the primary template is only declared on the next line; each
// supported traits combination gets an explicit specialization whose static
// member `bwd_v3_name` holds the name string for that combination.
// How the template arguments appear in the name strings of the rows below:
//   HDim 64 / 128 / 192      -> "hd64" / "hd128" / "hd192"
//   FmhaBwdBf16 / FmhaBwdFp16 -> "bf16" / "fp16"
//   kIsCausal true           -> "_causal"
//   kIsAtomic32 false / true -> "a16" / "a32"
//   BF16Cvt 0 / 1 / 2        -> "rtne" / "rtna" / "rtz" (bf16 rows only; fp16 rows always pass 0 and carry no suffix)
//   kIsSEQPad true           -> "pssk"; kIsHDPad true -> "pddv"; both true -> "psskddv"
// NOTE(review): this table passes a bool (kIsCausal) in the third position while
// a later table in the same view passes an integer MaskType there -- confirm
// which form matches the current traits template before editing rows.
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
|
||||
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a16_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_fp16_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_fp16_causal_a32_psskddv"; }};
|
||||
// Group-mode rows (bool kIsCausal form): every specialization in this table
// passes a trailing eighth argument `true` (column kIsGroupMode) and its name
// string ends in "_group". All rows here use kIsAtomic32 = true ("a32") and
// kIsSEQPad = true ("pssk" / "psskddv").
// BF16Cvt 0 / 1 / 2 appears as "rtne" / "rtna" / "rtz" in the bf16 rows; fp16 rows pass 0 and carry no suffix.
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_pssk_group"; }};
|
||||
// Name table keyed by an integer MaskType in the third template argument.
// Mapping visible in the rows below:
//   MaskType 0 / 1 / 2       -> no mask tag / "_causal" / "_swa"
//   HDim 64 / 128 / 192      -> "hd64" / "hd128" / "hd192"
//   FmhaBwdBf16 / FmhaBwdFp16 -> "bf16" / "fp16"
//   kIsAtomic32 false / true -> "a16" / "a32"
//   BF16Cvt 0 / 1 / 2        -> "rtne" / "rtna" / "rtz" (bf16 rows only; fp16 rows always pass 0 and carry no suffix)
//   kIsSEQPad true           -> "pssk"; kIsHDPad true -> "pddv"; both true -> "psskddv"
// The "_swa" rows (MaskType 2) exist only for HDim 128 with kIsAtomic32,
// kIsSEQPad and kIsHDPad all true.
// NOTE(review): "swa" presumably stands for sliding-window attention -- confirm.
// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a16_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16_pddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_fp16_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd192_fp16_causal_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv"; }};
|
||||
// Group-mode kernel names: the eighth argument (kIsGroupMode) is true and the name ends in "_group".
// hd64 entries: kIsSEQPad is set and kIsHDPad is clear, spelled "pssk"; MaskType 1 is spelled "causal".
// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; }};
|
||||
// hd128 group-mode kernel names (last argument true, "_group" suffix): fp16 first, then bf16 for BF16Cvt 0/1/2.
// Each data type appears with kIsHDPad set ("psskddv") and clear ("pssk"), and with the third argument 0 and 1 ("causal").
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_fp16_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_pssk_group"; }};
|
||||
|
||||
// Primary template of the traits -> kernel-buffer lookup. It is only declared here; the explicit
// specializations provide the bwd_v3_buf member.
// NOTE(review): if no definition exists elsewhere, using FmhaBwdV3Buf<T>::bwd_v3_buf for a traits
// combination without a specialization would be a compile error (incomplete type) -- confirm.
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
|
||||
// Kernel-buffer lookup: each specialization points bwd_v3_buf at the unsigned char symbol whose name
// encodes the same traits (presumably the embedded kernel binary -- TODO confirm).
// hd128 entries with both padding flags false: kIsCausal true adds "causal", kIsAtomic32 false/true
// is spelled "a16"/"a32", and bf16 BF16Cvt 0/1/2 is spelled "rtne"/"rtna"/"rtz".
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32; }};
|
||||
// hd128 buffers for the padded variants. The a16 entries set only kIsHDPad (symbol suffix "pddv");
// the a32 entries set both kIsSEQPad and kIsHDPad (symbol suffix "psskddv").
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtne_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtna_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtz_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtne_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtna_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtz_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a16_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a16_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_psskddv; }};
|
||||
// hd64 buffers. The a16 entries have both padding flags false; the a32 entries set kIsSEQPad only
// (symbol suffix "pssk"). kIsHDPad is false for every hd64 entry listed here.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk; }};
|
||||
// hd192 buffers. Only a32 entries with both kIsSEQPad and kIsHDPad set ("psskddv") are listed:
// bf16 for BF16Cvt 0/1/2 without and with kIsCausal, then the two fp16 entries.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_fp16_a32_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_fp16_causal_a32_psskddv; }};
|
||||
// Group-mode kernel buffers: the eighth argument (kIsGroupMode) is true and the symbol ends in "_group".
// hd64 entries: a32 with kIsSEQPad set and kIsHDPad clear ("pssk"), without and with kIsCausal.
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk_group; }};
|
||||
// hd128 group-mode buffers (last argument true, "_group" suffix): fp16 first, then bf16 for BF16Cvt 0/1/2.
// Each data type appears with kIsHDPad set ("psskddv") and clear ("pssk"), without and with kIsCausal.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_pssk_group; }};
|
||||
// Kernel-buffer lookup keyed by an integer MaskType: each specialization points bwd_v3_buf at the
// unsigned char symbol whose name encodes the same traits (presumably the embedded kernel binary -- TODO confirm).
// hd128 entries with both padding flags false: MaskType 0 adds nothing to the symbol and 1 adds "causal";
// kIsAtomic32 false/true is spelled "a16"/"a32"; bf16 BF16Cvt 0/1/2 is spelled "rtne"/"rtna"/"rtz".
// #######################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32; }};
|
||||
// hd128 buffers for the padded variants (third argument 0 or 1). The a16 entries set only kIsHDPad
// (symbol suffix "pddv"); the a32 entries set both kIsSEQPad and kIsHDPad (symbol suffix "psskddv").
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtne_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtna_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a16_rtz_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtne_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtna_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a16_rtz_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a16_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a16_pddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_psskddv; }};
|
||||
// hd64 buffers (third argument 0 or 1). The a16 entries have both padding flags false; the a32 entries
// set kIsSEQPad only (symbol suffix "pssk"). kIsHDPad is false for every hd64 entry listed here.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtne; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtna; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a16_rtz; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk; }};
|
||||
// hd192 buffers. Only a32 entries with both kIsSEQPad and kIsHDPad set ("psskddv") are listed:
// bf16 for BF16Cvt 0/1/2 with the third argument 0 and 1 ("causal"), then the two fp16 entries.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_bf16_causal_a32_rtz_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_fp16_a32_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd192_fp16_causal_a32_psskddv; }};
|
||||
// FmhaBwdV3Buf rows with MaskType 2 (HDim 128 only in this group); the bound symbols carry the
// "swa" tag in place of "causal".
// NOTE(review): swa presumably stands for sliding-window attention - confirm against the mask enum.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_swa_a32_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_swa_a32_rtne_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_swa_a32_rtna_psskddv; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_swa_a32_rtz_psskddv; }};
|
||||
// Group-mode FmhaBwdV3Buf rows: the eighth traits argument (kIsGroupMode) is true and each bound
// symbol in the HDim 64 rows below ends in _group. The column ruler that follows names the
// traits arguments in order.
// #######################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk_group; }};
|
||||
// Group-mode FmhaBwdV3Buf rows for HDim 128. All rows set kIsAtomic32 and kIsSEQPad to true;
// kIsHDPad true pairs with the psskddv tag in the symbol name and kIsHDPad false with pssk.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_causal_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_fp16_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_psskddv_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd128_bf16_causal_a32_rtz_pssk_group; }};
|
||||
|
||||
// Primary FmhaBwdV3Ts template, declared here without a body; the rows that follow supply one
// explicit specialization per traits combination, each carrying ts_qo and ts_kv.
template <typename Traits_>
struct FmhaBwdV3Ts;
|
||||
// FmhaBwdV3Ts rows keyed with a bool third argument (kIsCausal column), seven-argument form.
// HDim 128 with kIsSEQPad and kIsHDPad both false: every row sets ts_qo = 16 and ts_kv = 192.
// NOTE(review): ts_qo / ts_kv look like tile sizes along the q/o and k/v axes - confirm.
// NOTE(review): if the third traits parameter is an int (MaskType), false/true here convert to
// 0/1 and name the same specializations as the MaskType-keyed rows later in this table - confirm
// that only one of the two sets is meant to remain.
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (bool third argument) for HDim 128 with kIsHDPad true: ts_qo = 16 and
// ts_kv = 192 on every row. In these rows kIsSEQPad carries the same value as kIsAtomic32.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (bool third argument) for HDim 64: ts_qo = 32 and ts_kv = 192 on every row.
// kIsHDPad is false throughout, and kIsSEQPad carries the same value as kIsAtomic32.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 1, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, false, 2, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 1, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, false, 2, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (bool third argument) for HDim 192: ts_qo = 16 and ts_kv = 64 on every row.
// All rows set kIsAtomic32, kIsSEQPad and kIsHDPad to true.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
// Group-mode FmhaBwdV3Ts rows keyed with a bool third argument (kIsCausal column); the eighth
// traits argument (kIsGroupMode) is true. HDim 64 rows: ts_qo = 32 and ts_kv = 192.
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
// Group-mode FmhaBwdV3Ts rows (bool third argument) for HDim 128: ts_qo = 16 and ts_kv = 192 on
// every row. All rows set kIsAtomic32 and kIsSEQPad to true; kIsHDPad appears as both true and false.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows keyed with an integer third argument (MaskType column), seven-argument form.
// HDim 128 with kIsSEQPad and kIsHDPad both false: every row sets ts_qo = 16 and ts_kv = 192.
// NOTE(review): ts_qo / ts_kv look like tile sizes along the q/o and k/v axes - confirm.
// ######################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, false, false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (MaskType 0 or 1) for HDim 128 with kIsHDPad true: ts_qo = 16 and ts_kv = 192
// on every row. In these rows kIsSEQPad carries the same value as kIsAtomic32.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 1, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, false, 2, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 1, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, false, 2, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, false, 0, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (MaskType 0 or 1) for HDim 64: ts_qo = 32 and ts_kv = 192 on every row.
// kIsHDPad is false throughout, and kIsSEQPad carries the same value as kIsAtomic32.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 1, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, false, 2, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 1, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, false, 2, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
// FmhaBwdV3Ts rows (MaskType 0 or 1) for HDim 192: ts_qo = 16 and ts_kv = 64 on every row.
// All rows set kIsAtomic32, kIsSEQPad and kIsHDPad to true.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 0, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 1, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 0, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 1, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
// ######################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 1, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 0, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 0, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 1, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 1, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 1, true, 2, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, true, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 0, true, 2, true, false, true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
|
||||
|
||||
class fmha_bwd_v3_kernel
|
||||
{{
|
||||
@@ -817,7 +896,34 @@ class fmha_bwd_v3_kernel
|
||||
NULL,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}}
|
||||
|
||||
|
||||
// Launch the loaded v3 backward kernel (kernel_func) with a sliding-window / generic-mask
// argument block.
//
// The arguments are not passed through kernelParams (that slot is NULL). Instead the address
// and byte size of `args` are handed over through the trailing launch-parameter list
// (HIP_LAUNCH_PARAM_BUFFER_POINTER / HIP_LAUNCH_PARAM_BUFFER_SIZE / HIP_LAUNCH_PARAM_END).
//
// Parameters:
//   fmha_v3_traits - supplies s, ts_kv, h and b, which size the launch grid.
//   args           - kernel argument block, taken by value; &args and sizeof(args) are what
//                    the launch receives.
//   s              - stream configuration; the launch is issued on s.stream_id_.
//
// The call is wrapped in HIP_CALL; how a failing launch is reported is defined by that macro,
// which is not visible here.
void
|
||||
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_swa_genl_args args, const ck_tile::stream_config& s) const
|
||||
{{
|
||||
size_t arg_size = sizeof(args);
|
||||
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
|
||||
&args,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE,
|
||||
&arg_size,
|
||||
HIP_LAUNCH_PARAM_END}};
|
||||
|
||||
// Launch geometry: blocks of 256 x 1 x 1 threads; grid x = ceil(s / ts_kv), grid y = h, grid z = b.
int bdx = 256;
|
||||
int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
|
||||
int gdy = fmha_v3_traits.h;
|
||||
int gdz = fmha_v3_traits.b;
|
||||
|
||||
// Argument order per hipModuleLaunchKernel: function, grid x/y/z, block x/y/z,
// dynamic shared-memory bytes, stream, kernelParams, extra.
HIP_CALL(hipModuleLaunchKernel(kernel_func,
|
||||
gdx,
|
||||
gdy,
|
||||
gdz,
|
||||
bdx,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
s.stream_id_,
|
||||
NULL,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}}
|
||||
private:
|
||||
hipModule_t module;
|
||||
hipFunction_t kernel_func;
|
||||
@@ -1122,6 +1228,74 @@ float fmha_bwd_v3_group_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
);
|
||||
}}
|
||||
|
||||
// SWA is supposed to cover the following cases:
|
||||
// 1. FA style SWA: t/b: mask_left > 0 or mask_right > 0
|
||||
// 2. xformer style SWA: xt / xb: window_size > 0
|
||||
// 3. generic style SWA: g: x, y
|
||||
// after preprocessing, cases 1 & 2 can be merged into:
|
||||
// mask_type == mask_top_left or mask_bottom_right
|
||||
// left > 0 or right > 0
|
||||
// Build the argument block for the sliding-window / generic-mask v3 backward kernel and run
// the backward pass as three callables handed to ck_tile::launch_kernel, in this order:
//   1. fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>
//   2. fmha_bwd_v3_kernel::launch_kernel with the fmha_bwd_v3_swa_genl_args filled in below
//   3. fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>
//
// Template parameters:
//   dot_do_o_trait_     - trait type for stage 1.
//   dq_dk_dv_v3_traits_ - selects the kernel name (FmhaBwdV3Name), the kernel buffer
//                         (FmhaBwdV3Buf) and the tile sizes (FmhaBwdV3Ts) for stage 2.
//   convert_dq_trait_   - trait type for stage 3.
//
// Parameters:
//   s - stream configuration; when s.log_level_ > 0 the three stage names are printed.
//   a - problem description: tensor pointers, strides, sizes, mask type and window sizes.
//
// Returns the value produced by ck_tile::launch_kernel
// (presumably the measured time - confirm against ck_tile).
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
|
||||
float fmha_bwd_v3_swa_genl_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
// NOTE(review): args is default-initialized and only the named fields are assigned below;
// the _pN padding members of the packed struct are never written in this function.
// Confirm that the p1 type initializes itself or that the kernel ignores those bytes.
fmha_bwd_v3_swa_genl_args args;
|
||||
// dQ is written to the accumulator buffer (dq_acc_ptr), not to a final dq buffer;
// stage 3 is the convert_dq step.
args.ptr_dq = a.dq_acc_ptr;
|
||||
args.ptr_dk = a.dk_ptr;
|
||||
args.ptr_dv = a.dv_ptr;
|
||||
args.ptr_q = a.q_ptr;
|
||||
args.ptr_k = a.k_ptr;
|
||||
args.ptr_v = a.v_ptr;
|
||||
args.ptr_do = a.do_ptr;
|
||||
args.ptr_lse = a.lse_ptr;
|
||||
args.ptr_d = a.d_ptr;
|
||||
args.scalar = a.scale;
|
||||
args.log2e = ck_tile::log2e_v<float>;
|
||||
// Integer division of query heads by key heads
// (presumably the GQA/MQA group size - confirm).
args.ratio = a.nhead_q / a.nhead_k;
|
||||
args.seqlen_q = a.seqlen_q;
|
||||
args.seqlen_k = a.seqlen_k;
|
||||
args.head_dim = a.hdim_q;
|
||||
args.nhead_q = a.nhead_q;
|
||||
// Every stride below is doubled before being handed to the kernel - presumably converting
// element strides into byte strides for 2-byte fp16/bf16 data; TODO confirm.
args.Hs_q = a.nhead_stride_q * 2;
|
||||
args.BAs_q = a.batch_stride_q * 2;
|
||||
args.Seqs_q = a.stride_q * 2;
|
||||
args.Hs_k = a.nhead_stride_k * 2;
|
||||
args.BAs_k = a.batch_stride_k * 2;
|
||||
args.Seqs_k = a.stride_k * 2;
|
||||
args.Hs_v = a.nhead_stride_v * 2;
|
||||
args.BAs_v = a.batch_stride_v * 2;
|
||||
args.Seqs_v = a.stride_v * 2;
|
||||
args.Hs_do = a.nhead_stride_do * 2;
|
||||
args.BAs_do = a.batch_stride_do * 2;
|
||||
args.Seqs_do = a.stride_do * 2;
|
||||
args.Hs_dk = a.nhead_stride_dk * 2;
|
||||
args.BAs_dk = a.batch_stride_dk * 2;
|
||||
args.Seqs_dk = a.stride_dk * 2;
|
||||
args.Hs_dv = a.nhead_stride_dv * 2;
|
||||
args.BAs_dv = a.batch_stride_dv * 2;
|
||||
args.Seqs_dv = a.stride_dv * 2;
|
||||
|
||||
// convert l/r to x/y HERE
// The last argument is true when the mask type is mask_top_left or window_generic.
// Element 0 of the returned coordinates is stored as mask_y and element 1 as mask_x.
|
||||
auto generic_mask = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(a.window_size_left, a.window_size_right, a.seqlen_q, a.seqlen_k, (a.mask_type == static_cast<ck_tile::index_t>(mask_enum::mask_top_left) || a.mask_type == static_cast<ck_tile::index_t>(mask_enum::window_generic)));
|
||||
args.mask_y = generic_mask.at(ck_tile::number<0>{{}});
|
||||
args.mask_x = generic_mask.at(ck_tile::number<1>{{}});
|
||||
|
||||
// Positional initialization: batch, nhead_q, seqlen_k, hdim_q, mask_type, ts_qo, ts_kv.
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_k,
|
||||
a.hdim_q,
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
|
||||
// One kernel object per thread and per template instantiation, constructed on first use.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
@@ -1361,6 +1535,42 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{//group mode
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/) && (t.mask_type == mask_enum::mask_top_left)){{
|
||||
if(a.hdim_q == 128){{
|
||||
@@ -1823,6 +2033,112 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
|
||||
r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{//group mode
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/) && (t.mask_type == mask_enum::mask_top_left)){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
@@ -2459,8 +2775,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
@@ -2511,7 +2827,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
# CK tile example
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= bias in ['no']
|
||||
cond &= dbias in ['no']
|
||||
cond &= dropout in ['no']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2772
example/ck_tile/01_fmha/hsaco/bwd_hd128_bf16_swa_a32_rtz_psskddv.cpp
Normal file
2772
example/ck_tile/01_fmha/hsaco/bwd_hd128_bf16_swa_a32_rtz_psskddv.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2362
example/ck_tile/01_fmha/hsaco/bwd_hd128_fp16_swa_a32_psskddv.cpp
Normal file
2362
example/ck_tile/01_fmha/hsaco/bwd_hd128_fp16_swa_a32_psskddv.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -83,3 +83,7 @@ extern unsigned char bwd_hd192_bf16_causal_a32_rtne_psskddv[];
|
||||
extern unsigned char bwd_hd192_bf16_causal_a32_rtz_psskddv[];
|
||||
extern unsigned char bwd_hd192_fp16_a32_psskddv[];
|
||||
extern unsigned char bwd_hd192_fp16_causal_a32_psskddv[];
|
||||
extern unsigned char bwd_hd128_fp16_swa_a32_psskddv[];
|
||||
extern unsigned char bwd_hd128_bf16_swa_a32_rtna_psskddv[];
|
||||
extern unsigned char bwd_hd128_bf16_swa_a32_rtne_psskddv[];
|
||||
extern unsigned char bwd_hd128_bf16_swa_a32_rtz_psskddv[];
|
||||
|
||||
Reference in New Issue
Block a user