mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
qdo/kv strides split
This commit is contained in:
@@ -495,7 +495,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
}}
|
||||
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
|
||||
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm)
|
||||
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
|
||||
@@ -513,35 +513,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io
|
||||
args.log2e = ck_tile::log2e_v<float>;
|
||||
args.seq_len = a.seqlen_q;
|
||||
|
||||
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2;
|
||||
int stride_head = a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_seqlen = a.hdim_q * 2;
|
||||
|
||||
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_seqlen_kv = a.hdim_q * 2;
|
||||
int stride_seqlen_dkv = a.hdim_q * 2;
|
||||
if(io_perm == 0) //BSHD
|
||||
{{
|
||||
stride_seqlen = a.nhead_q * a.hdim_q * 2;
|
||||
stride_head = a.hdim_q * 2;
|
||||
|
||||
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
|
||||
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
|
||||
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
|
||||
stride_head_kv = a.hdim_q * 2;
|
||||
}}
|
||||
args.Ts = stride_tg;
|
||||
args.Hs = stride_head;
|
||||
args.BAs = stride_batch;
|
||||
args.Seqs = stride_seqlen;
|
||||
args.Ts = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.stride_k * 2;
|
||||
args.Hs = a.nhead_stride_q * 2;
|
||||
args.BAs = a.batch_stride_q * 2;
|
||||
args.Seqs = a.stride_q * 2;
|
||||
|
||||
args.ratio = a.nhead_q / a.nhead_k;
|
||||
args.Hs_kv = stride_head_kv;
|
||||
args.BAs_kv = stride_batch_kv;
|
||||
args.Seqs_kv = stride_seqlen_kv;
|
||||
args.Seqs_dkv = stride_seqlen_dkv;
|
||||
args.Hs_kv = a.nhead_stride_k * 2;
|
||||
args.BAs_kv = a.batch_stride_k * 2;
|
||||
args.Seqs_kv = a.stride_k * 2;
|
||||
args.Seqs_dkv = a.stride_dk * 2;
|
||||
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
@@ -607,7 +588,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_per
|
||||
}}
|
||||
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
|
||||
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm)
|
||||
float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
@@ -625,35 +606,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io
|
||||
args.log2e = ck_tile::log2e_v<float>;
|
||||
args.seq_len = a.seqlen_q;
|
||||
|
||||
int stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.hdim_q * 2;
|
||||
int stride_head = a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_seqlen = a.hdim_q * 2;
|
||||
|
||||
int stride_head_kv = a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2;
|
||||
int stride_seqlen_kv = a.hdim_q * 2;
|
||||
int stride_seqlen_dkv = a.hdim_q * 2;
|
||||
if(io_perm == 0) //BSHD
|
||||
{{
|
||||
stride_seqlen = a.nhead_q * a.hdim_q * 2;
|
||||
stride_head = a.hdim_q * 2;
|
||||
|
||||
stride_seqlen_kv = a.nhead_k * a.hdim_q * 2;
|
||||
stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2;
|
||||
stride_tg = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * stride_seqlen_kv;
|
||||
stride_head_kv = a.hdim_q * 2;
|
||||
}}
|
||||
args.Ts = stride_tg;
|
||||
args.Hs = stride_head;
|
||||
args.BAs = stride_batch;
|
||||
args.Seqs = stride_seqlen;
|
||||
args.Ts = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.stride_k * 2;
|
||||
args.Hs = a.nhead_stride_q * 2;
|
||||
args.BAs = a.batch_stride_q * 2;
|
||||
args.Seqs = a.stride_q * 2;
|
||||
|
||||
args.ratio = a.nhead_q / a.nhead_k;
|
||||
args.Hs_kv = stride_head_kv;
|
||||
args.BAs_kv = stride_batch_kv;
|
||||
args.Seqs_kv = stride_seqlen_kv;
|
||||
args.Seqs_dkv = stride_seqlen_dkv;
|
||||
args.Hs_kv = a.nhead_stride_k * 2;
|
||||
args.BAs_kv = a.batch_stride_k * 2;
|
||||
args.Seqs_kv = a.stride_k * 2;
|
||||
args.Seqs_dkv = a.stride_dk * 2;
|
||||
auto traits = fmha_bwd_v3_traits{{a.batch,
|
||||
a.nhead_q,
|
||||
a.seqlen_q,
|
||||
@@ -694,8 +656,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -703,8 +664,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -724,8 +684,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -733,8 +692,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -757,8 +715,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
@@ -766,8 +723,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
@@ -775,8 +731,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -786,24 +741,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -825,8 +777,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
@@ -834,8 +785,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
@@ -843,8 +793,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -854,24 +803,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -887,8 +833,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
@@ -896,8 +841,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
@@ -905,8 +849,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -915,24 +858,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -944,8 +884,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
@@ -953,8 +892,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
@@ -962,8 +900,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
@@ -972,24 +909,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>;
|
||||
const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 1){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
else if(t.how_v3_bf16_cvt == 2){{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>;
|
||||
using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>;
|
||||
// const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
|
||||
bool io_perm = a.nhead_stride_q > a.stride_q;
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a, io_perm);
|
||||
r = fmha_bwd_v3_xqa_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
}}
|
||||
|
||||
Reference in New Issue
Block a user