diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
index bafc573a7b..3bac90d1a4 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
@@ -449,6 +449,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                       o_host_ref,
                       num_batch,
                       1.0f / std::sqrt(params.hdim_qk),
+                      is_jagged ? max_seqlen : seqlen,
                       seq_offsets,
                       num_targets,
                       window_size,
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
index f0780520ca..5043c5ae9b 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp
@@ -141,6 +141,8 @@ struct HstuAttentionFwdKernel
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_o;
+
+        ck_tile::index_t max_seqlen;
     };
 
     struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
@@ -155,6 +157,7 @@ struct HstuAttentionFwdKernel
                                            HstuAttentionFwdEmptyKargs<2>>
     {
         const int32_t* seq_offsets_ptr;
+        ck_tile::index_t max_seqlen;
     };
 
     using Kargs = std::
@@ -219,7 +222,8 @@ struct HstuAttentionFwdKernel
                       batch_stride_q,
                       batch_stride_k,
                       batch_stride_v,
-                      batch_stride_o};
+                      batch_stride_o,
+                      seqlen}; // max_seqlen
 
         if constexpr(kHasBias)
         {
@@ -319,6 +323,7 @@ struct HstuAttentionFwdKernel
                       const void* bias_ptr,
                       void* o_ptr,
                       const void* seq_offsets_ptr,
+                      ck_tile::index_t max_seqlen,
                       ck_tile::index_t hdim_qk,
                       ck_tile::index_t hdim_v,
                       ck_tile::index_t num_head,
@@ -362,7 +367,8 @@ struct HstuAttentionFwdKernel
                       {}, // placeholder for bias
                       {}, // placeholder for mask
                       {}, // placeholder for dropout
-                      reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
+                      reinterpret_cast<const int32_t*>(seq_offsets_ptr),
+                      max_seqlen};
 
         if constexpr(kHasBias)
         {
@@ -393,6 +399,7 @@ struct HstuAttentionFwdKernel
                       const void* bias_ptr,
                       void* o_ptr,
                       const void* seq_offsets_ptr,
+                      ck_tile::index_t max_seqlen,
                       ck_tile::index_t hdim_qk,
                       ck_tile::index_t hdim_v,
                       ck_tile::index_t num_head,
@@ -421,6 +428,7 @@ struct HstuAttentionFwdKernel
                       bias_ptr,
                       o_ptr,
                       seq_offsets_ptr,
+                      max_seqlen,
                       hdim_qk,
                       hdim_v,
                       num_head,
@@ -732,6 +740,7 @@ struct HstuAttentionFwdKernel
                       bias_dram_window,
                       mask,
                       kargs.scale_s,
+                      kargs.max_seqlen,
                       smem_ptr,
                       dropout);
         }();
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
index 283c287341..7e6a74564a 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp
@@ -133,6 +133,7 @@ struct HstuAttentionFwdPipelineQRKSVS
                const OAccElementFunction& o_acc_element_func,
                HstuMask mask,
                float scale_s,
+               index_t max_seqlen, // used by silu
                void* smem_ptr,
                DropoutType& dropout) const
     {
@@ -232,16 +233,17 @@ struct HstuAttentionFwdPipelineQRKSVS
         statically_indexed_array pcomp_tiles;
 
         // elementwise silu activation (used in place of softmax)
-        const auto f_silu = [](CompDataType& x) {
+        const auto f_silu = [&](CompDataType& x) {
             const auto neg_one = ck_tile::type_convert<CompDataType>(-1.0f);
 
             if constexpr(std::is_same_v<CompDataType, float>)
             {
-                x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x));
+                x = x * __builtin_amdgcn_rcpf(neg_one - __expf(x)) *
+                    __builtin_amdgcn_rcpf(static_cast<float>(max_seqlen));
             }
             else
             {
-                x = x / (neg_one - exp(x));
+                x = x / (neg_one - exp(x)) / static_cast<CompDataType>(max_seqlen);
             }
         };
 
@@ -477,6 +479,7 @@ struct HstuAttentionFwdPipelineQRKSVS
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                HstuMask mask,
                float scale_s,
+               index_t max_seqlen,
                void* smem_ptr,
                DropoutType& dropout) const
     {
@@ -493,6 +496,7 @@ struct HstuAttentionFwdPipelineQRKSVS
                           identity{},
                           mask,
                           scale_s,
+                          max_seqlen,
                           smem_ptr,
                           dropout);
     }
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp
index b90778e939..b97d1b0977 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp
@@ -92,6 +92,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
             param.bias_ptr,
             param.o_ptr,
             param.seq_offsets_ptr,
+            param.max_seqlen,
             param.hdim_qk,
             param.hdim_v,
             param.num_head,
diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
index b5c9769fab..ce02f57e73 100644
--- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
+++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
@@ -42,6 +42,7 @@ struct reference_hstu_attention
                    HostTensor<ODataType>& o_batch_seq_nhead_hdim,
                    int num_batch,
                    float alpha,
+                   int max_seqlen,
                    std::vector<int32_t> seq_offsets,
                    std::vector<int32_t> num_targets, // define masking length at the end of token
                                                      // sequence to be excluded for attention
@@ -89,10 +90,10 @@ struct reference_hstu_attention
         // check num_targets
         assert(num_targets.empty() || num_targets.size() == num_batch);
 
-        auto silu = [](CompDataType x) {
+        auto silu = [&](CompDataType x) {
             const auto one = ck_tile::type_convert<CompDataType>(1.0f);
 
-            return x / (one + std::exp(-x));
+            return x / (one + std::exp(-x)) / ck_tile::type_convert<CompDataType>(max_seqlen);
         };
 
         auto f = [&](auto i_batch, auto i_head) {
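
Reviewer note (not part of the patch): taken together, these hunks thread a max_seqlen value from the example driver (max_seqlen for jagged inputs, seqlen otherwise) through the kernel arguments and pipeline down to the activation lambdas, so both the device pipeline and the host reference divide the silu output by max_seqlen, consistent with HSTU-style length-normalized pointwise attention. A minimal standalone sketch of the host-reference math, assuming float accumulation; normalized_silu is a hypothetical helper introduced here for illustration only:

    #include <cmath>

    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
    // The extra division by max_seqlen mirrors the factor this patch
    // adds to the reference implementation's silu lambda.
    float normalized_silu(float x, int max_seqlen)
    {
        return x / (1.0f + std::exp(-x)) / static_cast<float>(max_seqlen);
    }

On the device fast path the same factor is instead applied as a multiply by __builtin_amdgcn_rcpf(static_cast<float>(max_seqlen)), keeping the operation on the hardware reciprocal-approximation unit rather than issuing a full-precision divide.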