mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
support group mode for hd=64 of fa bwd v3 (#1990)
* support group mode for hd=64 of fa bwd v3 * fix bug in causal mask kernels when using kernel balance * tiny align --------- Co-authored-by: Wen.Yang <Wen.Yang@example.com> Co-authored-by: danyao12 <danyao12@amd.com>
This commit is contained in:
460
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Normal file → Executable file
460
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Normal file → Executable file
@@ -348,6 +348,55 @@ struct __attribute__((packed)) fmha_bwd_v3_genl_args
|
||||
p1 _p23;
|
||||
}};
|
||||
|
||||
// Argument block for the group-mode variant of the fmha bwd v3 kernel.
// The launch_kernel overload taking this struct passes it to the kernel as a raw
// buffer (HIP_LAUNCH_PARAM_BUFFER_POINTER), so the member order and the packed
// attribute presumably have to match what the kernel binary expects -- confirm
// before reordering or resizing anything here.
// NOTE(review): every scalar member is followed by a `p1` pad member. `p1` is
// declared outside this view; presumably it widens each scalar slot -- confirm.
struct __attribute__((packed)) fmha_bwd_v3_group_args
|
||||
{{
|
||||
// Gradient outputs. The host wrapper fills ptr_dq from dq_acc_ptr.
void* ptr_dq;
|
||||
void* ptr_dk;
|
||||
void* ptr_dv;
|
||||
// Read-only inputs.
const void* ptr_q;
|
||||
const void* ptr_k;
|
||||
const void* ptr_v;
|
||||
const void* ptr_do;
|
||||
const void* ptr_lse;
|
||||
const void* ptr_d;
|
||||
// Sequence-start tables for q and k (filled from seqstart_q_ptr / seqstart_k_ptr).
const void* ptr_qseq;
|
||||
const void* ptr_kseq;
|
||||
// Softmax scale (filled from a.scale by the host wrapper).
float scalar;
|
||||
p1 _p0;
|
||||
float log2e;
|
||||
p1 _p1;
|
||||
// nhead_q / nhead_k, as computed by the host wrapper.
unsigned int ratio;
|
||||
p1 _p2;
|
||||
unsigned int seqlen_q; //total length of q sequences
|
||||
p1 _p3;
|
||||
unsigned int seqlen_k; //total length of k sequences
|
||||
p1 _p4;
|
||||
// Per-tensor strides: Hs_* comes from nhead_stride_*, Seqs_* from stride_*.
// The host wrapper multiplies each by 2 -- presumably bytes per 16-bit element; confirm.
unsigned int Hs_q;
|
||||
p1 _p5;
|
||||
unsigned int Seqs_q;
|
||||
p1 _p6;
|
||||
unsigned int Hs_k;
|
||||
p1 _p7;
|
||||
unsigned int Seqs_k;
|
||||
p1 _p8;
|
||||
unsigned int Hs_v;
|
||||
p1 _p9;
|
||||
unsigned int Seqs_v;
|
||||
p1 _p10;
|
||||
unsigned int Hs_do;
|
||||
p1 _p11;
|
||||
unsigned int Seqs_do;
|
||||
p1 _p12;
|
||||
unsigned int Hs_dk;
|
||||
p1 _p13;
|
||||
unsigned int Seqs_dk;
|
||||
p1 _p14;
|
||||
unsigned int Hs_dv;
|
||||
p1 _p15;
|
||||
unsigned int Seqs_dv;
|
||||
p1 _p16;
|
||||
}};
|
||||
|
||||
struct fmha_bwd_v3_traits
|
||||
{{
|
||||
int b;
|
||||
@@ -366,7 +415,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsAtomic32_,
|
||||
ck_tile::index_t BF16Cvt_,
|
||||
bool kIsSEQPad_,
|
||||
bool kIsHDPad_>
|
||||
bool kIsHDPad_,
|
||||
bool kIsGroupMode_ = false>
|
||||
struct fmha_bwd_dq_dk_dv_v3_traits_
|
||||
{{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -376,6 +426,7 @@ struct fmha_bwd_dq_dk_dv_v3_traits_
|
||||
static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_;
|
||||
static constexpr bool kIsSEQPad = kIsSEQPad_;
|
||||
static constexpr bool kIsHDPad = kIsHDPad_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
}};
|
||||
|
||||
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
|
||||
@@ -428,6 +479,15 @@ template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; }};
|
||||
// ########################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group"; }};
|
||||
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr const char * bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; }};
|
||||
|
||||
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
|
||||
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
@@ -479,6 +539,15 @@ template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a16; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk; }};
|
||||
// #######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtne_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtna_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_bf16_causal_a32_rtz_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_a32_pssk_group; }};
|
||||
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr unsigned char * bwd_v3_buf = bwd_hd64_fp16_causal_a32_pssk_group; }};
|
||||
|
||||
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
|
||||
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
|
||||
@@ -530,6 +599,15 @@ template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, false, 0, false, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
// ######################################################|HDim| DataType|kIsCausal|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, false, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 1, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16, true, true, 2, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, false, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16, true, true, 0, true, false, true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
|
||||
|
||||
class fmha_bwd_v3_kernel
|
||||
{{
|
||||
@@ -636,6 +714,34 @@ class fmha_bwd_v3_kernel
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}}
|
||||
|
||||
// Launches the loaded v3 backward kernel with a group-mode argument block.
//   fmha_v3_traits : problem description; this overload reads s, ts_kv, mask, h and b.
//   args           : packed argument struct, passed to the kernel by value as a raw buffer.
//   s              : stream configuration; only stream_id_ is used here.
// Errors from the HIP launch are handled by the HIP_CALL macro (defined outside this view).
void
|
||||
launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_group_args args, const ck_tile::stream_config& s) const
|
||||
{{
|
||||
size_t arg_size = sizeof(args);
|
||||
// Arguments go through the "extra" launch parameter as one opaque buffer
// (pointer + size + end marker) instead of a kernelParams array.
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
|
||||
&args,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE,
|
||||
&arg_size,
|
||||
HIP_LAUNCH_PARAM_END}};
|
||||
|
||||
// Grid x = ceil(s / ts_kv), i.e. one workgroup per kv tile.
int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_kv - 1) / fmha_v3_traits.ts_kv;
|
||||
// With a mask enabled the x grid is halved, rounding up.
// NOTE(review): presumably each workgroup then covers two kv tiles to balance
// causal-mask work -- confirm against the kernel.
if(fmha_v3_traits.mask > 0)
|
||||
{{
|
||||
gdx = (gdx % 2) ? (gdx / 2 + 1) : (gdx / 2);
|
||||
}}
|
||||
// Grid is (gdx, h, b); block is 256x1x1; no dynamic shared memory; kernelParams is NULL
// because the arguments are supplied through config.
HIP_CALL(hipModuleLaunchKernel(kernel_func,
|
||||
gdx,
|
||||
fmha_v3_traits.h, /*gdy*/
|
||||
fmha_v3_traits.b, /*gdz*/
|
||||
256, /*bdx*/
|
||||
1, /*bdy*/
|
||||
1, /*bdz*/
|
||||
0,
|
||||
s.stream_id_,
|
||||
NULL,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}}
|
||||
|
||||
private:
|
||||
hipModule_t module;
|
||||
hipFunction_t kernel_func;
|
||||
@@ -884,14 +990,68 @@ float fmha_bwd_v3_genl_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
);
|
||||
}}
|
||||
|
||||
// Group-mode entry point for the v3 backward pass.
// Builds the packed group argument block from `a`, then runs three stages through
// ck_tile::launch_kernel: dot_do_o, the v3 dq/dk/dv kernel, and convert_dq.
// Returns whatever ck_tile::launch_kernel returns.
//   dot_do_o_trait_     : traits for the dot_do_o stage
//   dq_dk_dv_v3_traits_ : selects kernel name, kernel buffer and tile sizes
//                         via FmhaBwdV3Name / FmhaBwdV3Buf / FmhaBwdV3Ts
//   convert_dq_trait_   : traits for the convert_dq stage
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
|
||||
float fmha_bwd_v3_group_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
// Optional logging of the three stage names.
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
|
||||
fmha_bwd_v3_group_args args;
|
||||
// Typed views of the sequence-start tables; they are dereferenced on the host below.
auto seqstart_q = reinterpret_cast<const int32_t*>(a.seqstart_q_ptr);
|
||||
auto seqstart_k = reinterpret_cast<const int32_t*>(a.seqstart_k_ptr);
|
||||
// dq is written to the accumulation buffer; the convert_dq stage runs after the kernel.
args.ptr_dq = a.dq_acc_ptr;
|
||||
args.ptr_dk = a.dk_ptr;
|
||||
args.ptr_dv = a.dv_ptr;
|
||||
args.ptr_q = a.q_ptr;
|
||||
args.ptr_k = a.k_ptr;
|
||||
args.ptr_v = a.v_ptr;
|
||||
args.ptr_do = a.do_ptr;
|
||||
args.ptr_lse = a.lse_ptr;
|
||||
args.ptr_d = a.d_ptr;
|
||||
|
||||
args.scalar = a.scale;
|
||||
args.log2e = ck_tile::log2e_v<float>;
|
||||
args.ratio = a.nhead_q / a.nhead_k;
|
||||
// The entry at index a.batch is used as the total q / k length.
// NOTE(review): this reads through seqstart_*_ptr on the host. If those pointers
// refer to device-only memory this dereference is invalid -- confirm with callers.
args.seqlen_q = seqstart_q[a.batch];
|
||||
args.seqlen_k = seqstart_k[a.batch];
|
||||
// All strides are scaled by 2 -- presumably element strides converted to bytes
// for 16-bit fp16/bf16 data; confirm.
args.Hs_q = a.nhead_stride_q * 2;
|
||||
args.Seqs_q = a.stride_q * 2;
|
||||
args.Hs_k = a.nhead_stride_k * 2;
|
||||
args.Seqs_k = a.stride_k * 2;
|
||||
args.Hs_v = a.nhead_stride_v * 2;
|
||||
args.Seqs_v = a.stride_v * 2;
|
||||
args.Hs_do = a.nhead_stride_do * 2;
|
||||
args.Seqs_do = a.stride_do * 2;
|
||||
args.Hs_dk = a.nhead_stride_dk * 2;
|
||||
args.Seqs_dk = a.stride_dk * 2;
|
||||
args.Hs_dv = a.nhead_stride_dv * 2;
|
||||
args.Seqs_dv = a.stride_dv * 2;
|
||||
args.ptr_qseq = a.seqstart_q_ptr;
|
||||
args.ptr_kseq = a.seqstart_k_ptr;
|
||||
|
||||
// Positional init of the launch traits. The third value is max_seqlen_k.
// NOTE(review): presumably this lands in the `s` member that launch_kernel divides
// by ts_kv to size the grid -- confirm against the fmha_bwd_v3_traits member order.
auto traits = fmha_bwd_v3_traits{{ a.batch,
|
||||
a.nhead_q,
|
||||
a.max_seqlen_k,
|
||||
a.hdim_q,
|
||||
a.mask_type,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
|
||||
FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv }};
|
||||
// One kernel object per thread and per template instantiation, constructed on first use.
static thread_local fmha_bwd_v3_kernel impl(HSA_KERNEL, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
|
||||
// The lambdas capture a, traits and args by copy.
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
if (t.uses_bwd_v3 == true){{
|
||||
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
|
||||
if ((t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
|
||||
(t.is_deterministic == false) && (a.hdim_q == a.hdim_v) && (a.nhead_q % a.nhead_k == 0)) {{
|
||||
if((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0)){{
|
||||
if((t.is_group_mode == false) && (a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0)){{
|
||||
if(t.data_type.compare("fp16") == 0){{
|
||||
if(t.mask_type == mask_enum::no_mask){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
@@ -1413,20 +1573,30 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
if(t.data_type.compare("fp16") == 0){{
|
||||
if(t.mask_type == mask_enum::no_mask){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if(t.is_group_mode == false){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group";
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -1443,20 +1613,30 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if(t.is_group_mode == false){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, true>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group";
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -1475,59 +1655,78 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
else if(t.data_type.compare("bf16") == 0){{
|
||||
if(t.mask_type == mask_enum::no_mask){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if(t.is_group_mode == false){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
|
||||
@@ -1559,59 +1758,78 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
|
||||
if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
|
||||
if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
if(t.is_group_mode == false){{
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
|
||||
if(t.how_v3_bf16_cvt == 0){{
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
if(a.seqlen_q % 64 == 0){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else{{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
|
||||
r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, true>;
|
||||
r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
}}
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
2140
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtna_pssk_group.cpp
Executable file
2140
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtna_pssk_group.cpp
Executable file
File diff suppressed because it is too large
Load Diff
2260
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtne_pssk_group.cpp
Executable file
2260
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtne_pssk_group.cpp
Executable file
File diff suppressed because it is too large
Load Diff
1852
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtz_pssk_group.cpp
Executable file
1852
example/ck_tile/01_fmha/hsaco/bwd_hd64_bf16_a32_rtz_pssk_group.cpp
Executable file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1647
example/ck_tile/01_fmha/hsaco/bwd_hd64_fp16_a32_pssk_group.cpp
Executable file
1647
example/ck_tile/01_fmha/hsaco/bwd_hd64_fp16_a32_pssk_group.cpp
Executable file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
8
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
Normal file → Executable file
8
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
Normal file → Executable file
@@ -51,3 +51,11 @@ extern unsigned char bwd_hd64_fp16_a16[];
|
||||
extern unsigned char bwd_hd64_fp16_a32_pssk[];
|
||||
extern unsigned char bwd_hd64_fp16_causal_a16[];
|
||||
extern unsigned char bwd_hd64_fp16_causal_a32_pssk[];
|
||||
extern unsigned char bwd_hd64_bf16_a32_rtna_pssk_group[];
|
||||
extern unsigned char bwd_hd64_bf16_a32_rtne_pssk_group[];
|
||||
extern unsigned char bwd_hd64_bf16_a32_rtz_pssk_group[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a32_rtna_pssk_group[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a32_rtne_pssk_group[];
|
||||
extern unsigned char bwd_hd64_bf16_causal_a32_rtz_pssk_group[];
|
||||
extern unsigned char bwd_hd64_fp16_a32_pssk_group[];
|
||||
extern unsigned char bwd_hd64_fp16_causal_a32_pssk_group[];
|
||||
|
||||
Reference in New Issue
Block a user