mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Remove the use of MakeKargsImpl() to simplify the HSTU kernel
This commit is contained in:
@@ -200,38 +200,39 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{batch_stride_q,
|
||||
@@ -277,22 +278,21 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
auto offset = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
const void* seq_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
@@ -308,11 +308,6 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
@@ -320,71 +315,6 @@ struct HstuAttentionFwdKernel
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
o_ptr,
|
||||
seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_bias,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_o,
|
||||
num_targets_ptr,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
const void* seq_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
|
||||
@@ -426,76 +356,12 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
auto offset = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// MakeKargs overload (marked CK_TILE_HOST) that participates in overload
// resolution only when Cond is true; Cond defaults to kIsJagged, and
// std::enable_if_t<Cond, Kargs> yields the Kargs return type in that case.
//
// This is a thin forwarding wrapper: every argument from q_ptr through p_drop
// is passed to MakeKargsImpl() unchanged and in the same order. The only
// transformation is that philox_seed and philox_offset are packed into a
// single std::pair (via std::make_pair) and passed as the last argument.
//
// Returns whatever MakeKargsImpl() returns for those arguments.
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
const void* seq_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
float attn_scale,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
// Forward everything as-is; the seed/offset pair is built in the final argument.
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
o_ptr,
|
||||
seq_offsets_ptr,
|
||||
max_seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
attn_scale,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_bias,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
num_targets_ptr,
|
||||
contextual_seqlen,
|
||||
window_size,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_,
|
||||
|
||||
Reference in New Issue
Block a user