diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 7717f64fcd..881ec297ab 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -518,7 +518,7 @@ struct HstuAttentionFwdKernel } batch_offset_o = query_start * kargs.seq_stride_o; - kargs.seqlen = kargs.seq_offsets_ptr[1] - kargs.seq_offsets_ptr[0]; + kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch]; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 7a689a93a6..df6e3dddae 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -116,7 +116,7 @@ struct jagged_forward_causal_local_bias_dropout_dispatch }(); dim3 kGridSize = - HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v); + HstuKernel::GridSize(param.num_batch, param.num_head, param.max_seqlen, param.hdim_v); constexpr dim3 kBlockSize = HstuKernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;