support group mode for hd=64 of fa bwd v3 (#1990)

* support group mode for hd=64 of fa bwd v3

* bugfixed for causal mask kernels when using kernel balence

* tiny align

---------

Co-authored-by: Wen.Yang <Wen.Yang@example.com>
Co-authored-by: danyao12 <danyao12@amd.com>
This commit is contained in:
wen-des
2025-03-21 11:42:58 +08:00
committed by GitHub
parent e80ff1acbb
commit 8a25aa2669
10 changed files with 17636 additions and 121 deletions

460
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py Normal file → Executable file
View File

@@ -348,6 +348,55 @@ struct __attribute__((packed)) fmha_bwd_v3_genl_args
p1 _p23;
}};
struct __attribute__((packed)) fmha_bwd_v3_group_args
{{
void* ptr_dq;
void* ptr_dk;
void* ptr_dv;
const void* ptr_q;
const void* ptr_k;
const void* ptr_v;
const void* ptr_do;
const void* ptr_lse;
const void* ptr_d;
const void* ptr_qseq;
const void* ptr_kseq;
float scalar;
p1 _p0;
float log2e;
p1 _p1;
unsigned int ratio;
p1 _p2;
unsigned int seqlen_q; //total length of q sequences
p1 _p3;
unsigned int seqlen_k; //total length of k sequences
p1 _p4;
unsigned int Hs_q;
p1 _p5;
unsigned int Seqs_q;
p1 _p6;
unsigned int Hs_k;
p1 _p7;
unsigned int Seqs_k;
p1 _p8;
unsigned int Hs_v;
p1 _p9;
unsigned int Seqs_v;
p1 _p10;
unsigned int Hs_do;
p1 _p11;
unsigned int Seqs_do;
p1 _p12;
unsigned int Hs_dk;
p1 _p13;
unsigned int Seqs_dk;
p1 _p14;
unsigned int Hs_dv;
p1 _p15;
unsigned int Seqs_dv;
p1 _p16;
}};
struct fmha_bwd_v3_traits
{{
int b;
@@ -366,7 +415,8 @@ template <ck_tile::index_t HDim_,
bool kIsAtomic32_,
ck_tile::index_t BF16Cvt_,
bool kIsSEQPad_,
bool kIsHDPad_>
bool kIsHDPad_,
bool kIsGroupMode_ = false>
struct fmha_bwd_dq_dk_dv_v3_traits_
{{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -376,6 +426,7 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_;
static constexpr bool kIsSEQPad = kIsSEQPad_;
static constexpr bool kIsHDPad = kIsHDPad_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
}};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
@@ -428,6 +479,15 @@ template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; }};
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
@@ -479,6 +539,15 @@ template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk; }};
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk_group; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk_group; }};
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
@@ -530,6 +599,15 @@ template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
class fmha_bwd_v3_kernel
{{
@@ -636,6 +714,34 @@ class fmha_bwd_v3_kernel
reinterpret_cast<void**>(&config)));
}}
void
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_group_args args, const ck_tile::stream_config& s) const
{{
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
&args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&arg_size,
HIP_LAUNCH_PARAM_END}};
int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
if(fmha_v3_traits.mask > 0)
{{
gdx = (gdx % 2) ? (gdx / 2 + 1) : (gdx / 2);
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
gdx,
fmha_v3_traits.h, /*gdy*/
fmha_v3_traits.b, /*gdz*/
256, /*bdx*/
1, /*bdy*/
1, /*bdz*/
0,
s.stream_id_,
NULL,
reinterpret_cast<void**>(&config)));
}}
private:
hipModule_t module;
hipFunction_t kernel_func;
@@ -884,14 +990,68 @@ float fmha_bwd_v3_genl_(const ck_tile::stream_config& s, fmha_bwd_args a)
);
}}
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_group_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_v3_group_args args;
auto seqstart_q = reinterpret_cast<const int32_t*>(a.seqstart_q_ptr);
auto seqstart_k = reinterpret_cast<const int32_t*>(a.seqstart_k_ptr);
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_q = a.q_ptr;
args.ptr_k = a.k_ptr;
args.ptr_v = a.v_ptr;
args.ptr_do = a.do_ptr;
args.ptr_lse = a.lse_ptr;
args.ptr_d = a.d_ptr;
args.scalar = a.scale;
args.log2e = ck_tile::log2e_v<float>;
args.ratio = a.nhead_q / a.nhead_k;
args.seqlen_q = seqstart_q[a.batch];
args.seqlen_k = seqstart_k[a.batch];
args.Hs_q = a.nhead_stride_q * 2;
args.Seqs_q = a.stride_q * 2;
args.Hs_k = a.nhead_stride_k * 2;
args.Seqs_k = a.stride_k * 2;
args.Hs_v = a.nhead_stride_v * 2;
args.Seqs_v = a.stride_v * 2;
args.Hs_do = a.nhead_stride_do * 2;
args.Seqs_do = a.stride_do * 2;
args.Hs_dk = a.nhead_stride_dk * 2;
args.Seqs_dk = a.stride_dk * 2;
args.Hs_dv = a.nhead_stride_dv * 2;
args.Seqs_dv = a.stride_dv * 2;
args.ptr_qseq = a.seqstart_q_ptr;
args.ptr_kseq = a.seqstart_k_ptr;
auto traits = fmha_bwd_v3_traits{{ a.batch,
a.nhead_q,
a.max_seqlen_k,
a.hdim_q,
a.mask_type,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv }};
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
if (t.uses_bwd_v3 == true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if ((t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(t.is_deterministic == false) && (a.hdim_q == a.hdim_v) && (a.nhead_q % a.nhead_k == 0)) {{
if((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0)){{
if((t.is_group_mode == false) && (a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0)){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
@@ -1413,20 +1573,30 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
if(t.is_group_mode == false){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group";
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
@@ -1443,20 +1613,30 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
if(t.is_group_mode == false){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, true>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group";
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
@@ -1475,59 +1655,78 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.how_v3_bf16_cvt == 0){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
if(t.is_group_mode == false){{
if(t.how_v3_bf16_cvt == 0){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
else if(t.how_v3_bf16_cvt == 1){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
if(t.how_v3_bf16_cvt == 0){{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
else if(t.how_v3_bf16_cvt == 1){{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
return r;
}}
}}
else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
@@ -1559,59 +1758,78 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
if(t.how_v3_bf16_cvt == 0){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
if(t.is_group_mode == false){{
if(t.how_v3_bf16_cvt == 0){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
else if(t.how_v3_bf16_cvt == 1){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
}}
else if(t.how_v3_bf16_cvt == 1){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
if(t.how_v3_bf16_cvt == 0){{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
else if(t.how_v3_bf16_cvt == 1){{
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
}}
else if(t.how_v3_bf16_cvt == 2){{
if(a.seqlen_q % 64 == 0){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
return r;
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, true>;
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
}}
return r;
}}
}}
}}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

8
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp Normal file → Executable file
View File

@@ -51,3 +51,11 @@ extern unsigned char bwd_hd64_fp16_a16[];
extern unsigned char bwd_hd64_fp16_a32_pssk[];
extern unsigned char bwd_hd64_fp16_causal_a16[];
extern unsigned char bwd_hd64_fp16_causal_a32_pssk[];
extern unsigned char bwd_hd64_bf16_a32_rtna_pssk_group[];
extern unsigned char bwd_hd64_bf16_a32_rtne_pssk_group[];
extern unsigned char bwd_hd64_bf16_a32_rtz_pssk_group[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtna_pssk_group[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtne_pssk_group[];
extern unsigned char bwd_hd64_bf16_causal_a32_rtz_pssk_group[];
extern unsigned char bwd_hd64_fp16_a32_pssk_group[];
extern unsigned char bwd_hd64_fp16_causal_a32_pssk_group[];