diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 382aa75979..fe6fb7a4ed 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -200,38 +200,39 @@ struct HstuAttentionFwdKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* o_ptr, - ck_tile::index_t seqlen, - ck_tile::index_t hdim_qk, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head, - float scale_s, - float attn_scale, - ck_tile::index_t seq_stride_q, - ck_tile::index_t seq_stride_k, - ck_tile::index_t seq_stride_v, - ck_tile::index_t seq_stride_bias, - ck_tile::index_t seq_stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_o, - const void* num_targets_ptr, - ck_tile::index_t contextual_seqlen, - ck_tile::index_t window_size, - ck_tile::index_t min_full_attn_seqlen, - float p_drop, - const std::pair& drop_seed_offset) + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + ck_tile::index_t seqlen, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + float attn_scale, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_o, + const void* num_targets_ptr, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + uint64_t philox_seed, + uint64_t philox_offset) { Kargs kargs{ {batch_stride_q, @@ -277,22 +278,21 @@ struct HstuAttentionFwdKernel } if constexpr(kHasDropout) { - auto seed = std::get<0>(drop_seed_offset); - auto offset = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, seed, offset); + kargs.init_dropout(p_drop, philox_seed, philox_offset); } return kargs; } - template + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr, void* o_ptr, - ck_tile::index_t seqlen, + const void* seq_offsets_ptr, + ck_tile::index_t max_seqlen, ck_tile::index_t hdim_qk, ck_tile::index_t hdim_v, ck_tile::index_t num_head, @@ -308,11 +308,6 @@ struct HstuAttentionFwdKernel ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_o, const void* num_targets_ptr, ck_tile::index_t contextual_seqlen, ck_tile::index_t window_size, @@ -320,71 +315,6 @@ struct HstuAttentionFwdKernel float p_drop, uint64_t philox_seed, uint64_t philox_offset) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - o_ptr, - seqlen, - hdim_qk, - hdim_v, - num_head, - scale_s, - attn_scale, - seq_stride_q, - seq_stride_k, - seq_stride_v, - seq_stride_bias, - seq_stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - num_targets_ptr, - contextual_seqlen, - window_size, - min_full_attn_seqlen, - p_drop, - std::make_pair(philox_seed, philox_offset)); - } - - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* o_ptr, - const void* seq_offsets_ptr, - ck_tile::index_t max_seqlen, - ck_tile::index_t hdim_qk, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head, - float scale_s, - float attn_scale, - ck_tile::index_t seq_stride_q, - ck_tile::index_t seq_stride_k, - ck_tile::index_t seq_stride_v, - ck_tile::index_t seq_stride_bias, - ck_tile::index_t seq_stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_o, - const void* num_targets_ptr, - ck_tile::index_t contextual_seqlen, - ck_tile::index_t window_size, - ck_tile::index_t min_full_attn_seqlen, - float p_drop, - const std::pair& drop_seed_offset) { Kargs kargs{ {reinterpret_cast(seq_offsets_ptr), @@ -426,76 +356,12 @@ struct HstuAttentionFwdKernel } if constexpr(kHasDropout) { - auto seed = std::get<0>(drop_seed_offset); - auto offset = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, seed, offset); + kargs.init_dropout(p_drop, philox_seed, philox_offset); } return kargs; } - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* o_ptr, - const void* seq_offsets_ptr, - ck_tile::index_t max_seqlen, - ck_tile::index_t hdim_qk, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head, - float scale_s, - float attn_scale, - ck_tile::index_t seq_stride_q, - ck_tile::index_t seq_stride_k, - ck_tile::index_t seq_stride_v, - ck_tile::index_t seq_stride_bias, - ck_tile::index_t seq_stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_o, - const void* num_targets_ptr, - ck_tile::index_t contextual_seqlen, - ck_tile::index_t window_size, - ck_tile::index_t min_full_attn_seqlen, - float p_drop, - uint64_t philox_seed, - uint64_t philox_offset) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - o_ptr, - seq_offsets_ptr, - max_seqlen, - hdim_qk, - hdim_v, - num_head, - scale_s, - attn_scale, - seq_stride_q, - seq_stride_k, - seq_stride_v, - seq_stride_bias, - seq_stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - num_targets_ptr, - contextual_seqlen, - window_size, - min_full_attn_seqlen, - p_drop, - std::make_pair(philox_seed, philox_offset)); - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_,