mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] MakeKargs overloads for backward compatibility (#1681)
* Add overloads for MakeKargs Overload MakeKargs to accept std::tuple<uint64_t, uint64_t> and std::tuple<void*, void*> to preserve functionality of code currently passing in list initializers or tuples. * Add overloads for MakeKargs Overload MakeKargs to accept std::tuple<uint64_t, uint64_t> and std::tuple<void*, void*> to preserve functionality of code currently passing in list initializers or tuples. * Re-format files using ck_tile remod.py --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -470,6 +470,248 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<void*, void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
@@ -616,6 +858,208 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<void*, void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
|
||||
@@ -399,6 +399,186 @@ struct FmhaFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
MakeKargs(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<void*, void*>& drop_seed_offset)
|
||||
{
|
||||
MakeKargs(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
@@ -522,6 +702,164 @@ struct FmhaFwdKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<void*, void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargs(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
rand_val_ptr,
|
||||
lse_ptr,
|
||||
o_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
|
||||
Reference in New Issue
Block a user