[CK_TILE] Fix fMHA fwd MakeKargs() compilation errors (#1689)

* Fix mis-matched tuple<> elem types

* Rename MakeKargs() as MakeKargsImpl()

---------

Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
Po Yen Chen
2024-11-25 15:30:35 +08:00
committed by GitHub
parent c2bcbb1379
commit 645fe812f6
4 changed files with 484 additions and 482 deletions

View File

@@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
@@ -714,54 +714,54 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,

View File

@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
@@ -267,50 +267,50 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -399,9 +399,9 @@ struct FmhaFwdKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
@@ -445,53 +445,54 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
@@ -533,91 +534,92 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
rand_val_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
window_size_left,
window_size_right,
mask_type,
p_drop,
s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -702,9 +704,9 @@ struct FmhaFwdKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
@@ -742,7 +744,7 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
@@ -781,9 +783,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
@@ -819,9 +821,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
@@ -860,15 +862,15 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{