From f81addbe42aca88bacb4830da1fc9ba550ca4c7d Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 25 Nov 2024 15:30:35 +0800 Subject: [PATCH] [CK_TILE] Fix fMHA fwd MakeKargs() compilation errors (#1689) * Fix mis-matched tuple<> elem types * Rename MakeKargs() as MakeKargsImpl() --------- Co-authored-by: Qianfeng [ROCm/composable_kernel commit: 645fe812f65db86a9eaca7ae00e0004c1634bc0a] --- example/ck_tile/01_fmha/fmha_bwd.hpp | 208 +++++----- example/ck_tile/01_fmha/fmha_fwd.hpp | 156 ++++---- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 232 +++++------ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 370 +++++++++--------- 4 files changed, 484 insertions(+), 482 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 3b21a3257f..722ef15a2f 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -150,113 +150,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_do, - args.batch_stride_lsed, - args.batch_stride_dq_acc, - args.batch_stride_dk, - args.batch_stride_dv, - args.batch_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 41edac67ba..704453baa4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -281,87 +281,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { - return FmhaKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index ccf15ee600..23174528e7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_do, - ck_tile::index_t batch_stride_lsed, - ck_tile::index_t batch_stride_dq_acc, - ck_tile::index_t batch_stride_dk, - ck_tile::index_t batch_stride_dv, - ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - std::variant, std::pair> - drop_seed_offset) + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + void* dq_acc_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dq_acc, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dq_acc, + ck_tile::index_t batch_stride_dk, + ck_tile::index_t batch_stride_dv, + ck_tile::index_t batch_stride_dbias, + ck_tile::index_t split_stride_dq_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel float p_drop, const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, @@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, @@ -714,54 +714,54 @@ struct FmhaBwdDQDKDVKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - std::variant, std::pair> - drop_seed_offset) + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + void* dq_acc_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dq_acc, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dq_acc, + ck_tile::index_t nhead_stride_dk, + ck_tile::index_t nhead_stride_dv, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t split_stride_dq_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel float p_drop, const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, @@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 4443a45038..3de433d6a7 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -64,7 +64,7 @@ struct FmhaFwdKernel template <> struct t2s { static constexpr const char * name = "bf8"; }; // clang-format on - __host__ static std::string GetName() + CK_TILE_HOST static std::string GetName() { // sync with generate.py // clang-format off @@ -267,50 +267,50 @@ struct FmhaFwdKernel using Kargs = std::conditional_t; template - __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -399,9 +399,9 @@ struct FmhaFwdKernel return kargs; } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -445,53 +445,54 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset) { - MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -533,91 +534,92 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset) { - MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } template - __host__ static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -702,9 +704,9 @@ struct FmhaFwdKernel return kargs; } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -742,7 +744,7 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, @@ -781,9 +783,9 @@ struct FmhaFwdKernel std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } - // std::variant can't take in a list initializer, overload for backward compatibility + // std::variant<> can't take in a list initializer, overload for backward compatibility template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -819,9 +821,9 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset) { - return MakeKargs( + return MakeKargsImpl( q_ptr, k_ptr, v_ptr, @@ -860,15 +862,15 @@ struct FmhaFwdKernel std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) { return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); } - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() {