diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 76013bf2e3..bb9977d189 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -185,16 +185,19 @@ struct HstuAttentionFwdKernel const float* group_attn_scale_ptr; }; - struct HstuAttentionFwdCommonBiasKargs + struct HstuAttentionFwdBatchedBiasKargs { - const void* bias_ptr = nullptr; - ck_tile::index_t seq_stride_bias = 0; - ck_tile::index_t nhead_stride_bias = 0; + const void* bias_ptr; + ck_tile::index_t seq_stride_bias; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t batch_stride_bias; }; - struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs + struct HstuAttentionFwdJaggedBiasKargs { - ck_tile::index_t batch_stride_bias = 0; + const void* bias_ptr; + ck_tile::index_t seq_stride_bias; + ck_tile::index_t nhead_stride_bias; }; struct HstuAttentionFwdDropoutSeedOffset @@ -238,7 +241,7 @@ struct HstuAttentionFwdKernel struct HstuAttentionNoGroupBatchedFwdKargs : HstuAttentionNoGroupBatchedFwdBaseKargs, std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t