From 0c126ffc19fe36a651769515ac862ca21b792d23 Mon Sep 17 00:00:00 2001 From: danyao12 Date: Fri, 3 Jan 2025 16:01:30 +0800 Subject: [PATCH] qdo/kv strides split --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 158 +++++------------- 1 file changed, 46 insertions(+), 112 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6cf47e5a8a..b8025612cd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -495,7 +495,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) }} template -float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm) +float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << std::flush; @@ -513,35 +513,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io args.log2e = ck_tile::log2e_v; args.seq_len = a.seqlen_q; - int stride_tg = FmhaBwdV3Ts::ts_kv * a.hdim_q * 2; - int stride_head = a.seqlen_q * a.hdim_q * 2; - int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2; - int stride_seqlen = a.hdim_q * 2; - - int stride_head_kv = a.seqlen_q * a.hdim_q * 2; - int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2; - int stride_seqlen_kv = a.hdim_q * 2; - int stride_seqlen_dkv = a.hdim_q * 2; - if(io_perm == 0) //BSHD - {{ - stride_seqlen = a.nhead_q * a.hdim_q * 2; - stride_head = a.hdim_q * 2; - - stride_seqlen_kv = a.nhead_k * a.hdim_q * 2; - stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2; - stride_tg = FmhaBwdV3Ts::ts_kv * stride_seqlen_kv; - stride_head_kv = a.hdim_q * 2; - }} - args.Ts = stride_tg; - args.Hs = stride_head; - args.BAs = stride_batch; - args.Seqs = stride_seqlen; + args.Ts = FmhaBwdV3Ts::ts_kv * a.stride_k * 2; + args.Hs = a.nhead_stride_q * 2; + args.BAs = a.batch_stride_q * 2; + args.Seqs = a.stride_q * 2; args.ratio = a.nhead_q / a.nhead_k; - args.Hs_kv = stride_head_kv; - args.BAs_kv = stride_batch_kv; - args.Seqs_kv = stride_seqlen_kv; - args.Seqs_dkv = stride_seqlen_dkv; + args.Hs_kv = a.nhead_stride_k * 2; + args.BAs_kv = a.batch_stride_k * 2; + args.Seqs_kv = a.stride_k * 2; + args.Seqs_dkv = a.stride_dk * 2; auto traits = fmha_bwd_v3_traits{{a.batch, a.nhead_q, a.seqlen_q, @@ -607,7 +588,7 @@ float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_per }} template -float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io_perm) +float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; @@ -625,35 +606,16 @@ float fmha_bwd_v3_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, bool io args.log2e = ck_tile::log2e_v; args.seq_len = a.seqlen_q; - int stride_tg = FmhaBwdV3Ts::ts_kv * a.hdim_q * 2; - int stride_head = a.seqlen_q * a.hdim_q * 2; - int stride_batch = a.nhead_q * a.seqlen_q * a.hdim_q * 2; - int stride_seqlen = a.hdim_q * 2; - - int stride_head_kv = a.seqlen_q * a.hdim_q * 2; - int stride_batch_kv = a.nhead_k * a.seqlen_q * a.hdim_q * 2; - int stride_seqlen_kv = a.hdim_q * 2; - int stride_seqlen_dkv = a.hdim_q * 2; - if(io_perm == 0) //BSHD - {{ - stride_seqlen = a.nhead_q * a.hdim_q * 2; - stride_head = a.hdim_q * 2; - - stride_seqlen_kv = a.nhead_k * a.hdim_q * 2; - stride_seqlen_dkv = a.nhead_q * a.hdim_q * 2; - stride_tg = FmhaBwdV3Ts::ts_kv * stride_seqlen_kv; - stride_head_kv = a.hdim_q * 2; - }} - args.Ts = stride_tg; - args.Hs = stride_head; - args.BAs = stride_batch; - args.Seqs = stride_seqlen; + args.Ts = FmhaBwdV3Ts::ts_kv * a.stride_k * 2; + args.Hs = a.nhead_stride_q * 2; + args.BAs = a.batch_stride_q * 2; + args.Seqs = a.stride_q * 2; args.ratio = a.nhead_q / a.nhead_k; - args.Hs_kv = stride_head_kv; - args.BAs_kv = stride_batch_kv; - args.Seqs_kv = stride_seqlen_kv; - args.Seqs_dkv = stride_seqlen_dkv; + args.Hs_kv = a.nhead_stride_k * 2; + args.BAs_kv = a.batch_stride_k * 2; + args.Seqs_kv = a.stride_k * 2; + args.Seqs_dkv = a.stride_dk * 2; auto traits = fmha_bwd_v3_traits{{a.batch, a.nhead_q, a.seqlen_q, @@ -694,8 +656,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_a32"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -703,8 +664,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_fp16_a16"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -724,8 +684,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a32"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -733,8 +692,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::fp16_t, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_fp16_causal_a16"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -757,8 +715,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -766,8 +723,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 1>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -775,8 +731,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, true, false, 2>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_a32_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -786,24 +741,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 1>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, false, false, false, 2>; // const std::string bwd_v3_name = "bwd_v3_bf16_a16_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -825,8 +777,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -834,8 +785,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 1>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -843,8 +793,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, true, false, 2>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a32_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -854,24 +803,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 1>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, ck_tile::bf16_t, true, false, false, 2>; // const std::string bwd_v3_name = "bwd_v3_bf16_causal_a16_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -887,8 +833,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -896,8 +841,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 1>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -905,8 +849,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, true, false, 2>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -915,24 +858,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 1>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, false, false, false, 2>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -944,8 +884,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 0>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ @@ -953,8 +892,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 1>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ @@ -962,8 +900,7 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, true, false, 2>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, ck_tile::bf16_t, false, false, false, false>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }} @@ -972,24 +909,21 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 0>; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 1){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 1>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} else if(t.how_v3_bf16_cvt == 2){{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, ck_tile::bf16_t, false, false, false>; using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, ck_tile::bf16_t, true, false, false, 2>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; - bool io_perm = a.nhead_stride_q > a.stride_q; - r = fmha_bwd_v3_xqa_(s, a, io_perm); + r = fmha_bwd_v3_xqa_(s, a); return r; }} }}