Remove using MakeKargsImpl() to simplify the hstu kernel

This commit is contained in:
Qianfeng Zhang
2025-09-10 15:24:20 +00:00
parent a8c62920bf
commit 2668bb3aee

View File

@@ -200,38 +200,39 @@ struct HstuAttentionFwdKernel
template <bool Cond = !kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
ck_tile::index_t seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
ck_tile::index_t seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
uint64_t philox_seed,
uint64_t philox_offset)
{
Kargs kargs{
{batch_stride_q,
@@ -277,22 +278,21 @@ struct HstuAttentionFwdKernel
}
if constexpr(kHasDropout)
{
auto seed = std::get<0>(drop_seed_offset);
auto offset = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
return kargs;
}
template <bool Cond = !kIsJagged>
template <bool Cond = kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
ck_tile::index_t seqlen,
const void* seq_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
@@ -308,11 +308,6 @@ struct HstuAttentionFwdKernel
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
@@ -320,71 +315,6 @@ struct HstuAttentionFwdKernel
float p_drop,
uint64_t philox_seed,
uint64_t philox_offset)
{
return MakeKargsImpl(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
o_ptr,
seqlen,
hdim_qk,
hdim_v,
num_head,
scale_s,
attn_scale,
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_bias,
seq_stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_o,
num_targets_ptr,
contextual_seqlen,
window_size,
min_full_attn_seqlen,
p_drop,
std::make_pair(philox_seed, philox_offset));
}
template <bool Cond = kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
const void* seq_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
@@ -426,76 +356,12 @@ struct HstuAttentionFwdKernel
}
if constexpr(kHasDropout)
{
auto seed = std::get<0>(drop_seed_offset);
auto offset = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
return kargs;
}
template <bool Cond = kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
const void* seq_offsets_ptr,
ck_tile::index_t max_seqlen,
ck_tile::index_t hdim_qk,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head,
float scale_s,
float attn_scale,
ck_tile::index_t seq_stride_q,
ck_tile::index_t seq_stride_k,
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
ck_tile::index_t min_full_attn_seqlen,
float p_drop,
uint64_t philox_seed,
uint64_t philox_offset)
{
return MakeKargsImpl(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
o_ptr,
seq_offsets_ptr,
max_seqlen,
hdim_qk,
hdim_v,
num_head,
scale_s,
attn_scale,
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_bias,
seq_stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_o,
num_targets_ptr,
contextual_seqlen,
window_size,
min_full_attn_seqlen,
p_drop,
std::make_pair(philox_seed, philox_offset));
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_,