From 0711f4f90a61b471f3668dcb9f033461e343408a Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 6 Feb 2026 15:40:07 +0000 Subject: [PATCH] Add is_cross_attention as both host API and kernel parameter so that separate masking rules are used for self or cross attention --- .../example_hstu_attention.cpp | 76 +++++++++++-------- ...stu_attention_batched_forward_dispatch.hpp | 6 +- .../hstu_attention_fwd_kernel.hpp | 14 +++- ...hstu_attention_jagged_forward_dispatch.hpp | 6 +- .../hstu_attention_params.hpp | 5 ++ .../reference_hstu_attention.hpp | 5 +- 6 files changed, 70 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 638bc468b2..5c9f0ad3e9 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -263,6 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser) int max_target = 0; + bool is_cross_attention = false; + if(!num_targets.empty()) { // supplement num_targets using the last input value if user-provided lengths not enough @@ -275,9 +277,12 @@ bool run(const ck_tile::ArgParser& arg_parser) HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!"); - // assume seq_lengths_kv is same as seq_lengths_q if not defined + // assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when + // seq_lengths_kv is explicitly defined, we think the input case is a cross_attention case if(seq_lengths_kv.empty()) seq_lengths_kv = seq_lengths_q; + else + is_cross_attention = true; // assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined if(input_max_uih_seqlen_kv <= 0) @@ -285,10 +290,12 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_jagged) { - // supplement seq_lengths_q using the last input value if user-provided lengths not enough + // supplement seq_lengths_q using the last input 
value if user-provided lengths not + // enough supplement_array_by_last_element(seq_lengths_q, num_batch); - // supplement seq_lengths_kv using the last input value if user-provided lengths not enough + // supplement seq_lengths_kv using the last input value if user-provided lengths not + // enough supplement_array_by_last_element(seq_lengths_kv, num_batch); // only consider num_batch values even if more values are provided by the user @@ -452,6 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(is_jagged) { + params.is_cross_attention = is_cross_attention; params.is_jagged = true; params.num_batch = num_batch; params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer(); @@ -489,35 +497,36 @@ bool run(const ck_tile::ArgParser& arg_parser) } else { - params.is_jagged = false; - params.num_batch = num_batch; - params.seqlen_q = phy_seqlen_q; - params.seqlen_kv = phy_seqlen_kv; - params.q_ptr = q_dev.GetDeviceBuffer(); - params.k_ptr = k_dev.GetDeviceBuffer(); - params.v_ptr = v_dev.GetDeviceBuffer(); - params.bias_ptr = nullptr; // bias is not supported at present - params.o_ptr = o_dev.GetDeviceBuffer(); - params.hdim_qk = hdim_qk; - params.hdim_v = hdim_v; - params.num_head = num_head; - params.scale_s = scale_s; - params.attn_scale = attn_scale; - params.seq_stride_q = q_host.get_strides()[1]; - params.seq_stride_k = k_host.get_strides()[1]; - params.seq_stride_v = v_host.get_strides()[1]; - params.seq_stride_bias = 0; - params.seq_stride_o = o_host_ref.get_strides()[1]; - params.nhead_stride_q = q_host.get_strides()[2]; - params.nhead_stride_k = k_host.get_strides()[2]; - params.nhead_stride_v = v_host.get_strides()[2]; - params.nhead_stride_bias = 0; - params.nhead_stride_o = o_host_ref.get_strides()[2]; - params.batch_stride_q = q_host.get_strides()[0]; - params.batch_stride_k = k_host.get_strides()[0]; - params.batch_stride_v = v_host.get_strides()[0]; - params.batch_stride_bias = 0; - params.batch_stride_o = o_host_ref.get_strides()[0]; + 
params.is_cross_attention = is_cross_attention; + params.is_jagged = false; + params.num_batch = num_batch; + params.seqlen_q = phy_seqlen_q; + params.seqlen_kv = phy_seqlen_kv; + params.q_ptr = q_dev.GetDeviceBuffer(); + params.k_ptr = k_dev.GetDeviceBuffer(); + params.v_ptr = v_dev.GetDeviceBuffer(); + params.bias_ptr = nullptr; // bias is not supported at present + params.o_ptr = o_dev.GetDeviceBuffer(); + params.hdim_qk = hdim_qk; + params.hdim_v = hdim_v; + params.num_head = num_head; + params.scale_s = scale_s; + params.attn_scale = attn_scale; + params.seq_stride_q = q_host.get_strides()[1]; + params.seq_stride_k = k_host.get_strides()[1]; + params.seq_stride_v = v_host.get_strides()[1]; + params.seq_stride_bias = 0; + params.seq_stride_o = o_host_ref.get_strides()[1]; + params.nhead_stride_q = q_host.get_strides()[2]; + params.nhead_stride_k = k_host.get_strides()[2]; + params.nhead_stride_v = v_host.get_strides()[2]; + params.nhead_stride_bias = 0; + params.nhead_stride_o = o_host_ref.get_strides()[2]; + params.batch_stride_q = q_host.get_strides()[0]; + params.batch_stride_k = k_host.get_strides()[0]; + params.batch_stride_v = v_host.get_strides()[0]; + params.batch_stride_bias = 0; + params.batch_stride_o = o_host_ref.get_strides()[0]; params.num_targets_ptr = num_targets.empty() ? 
nullptr : num_targets_dev.GetDeviceBuffer(); params.use_softmax = use_softmax; params.use_causal = use_causal; @@ -566,7 +575,8 @@ bool run(const ck_tile::ArgParser& arg_parser) CompDataType, kIsJagged, kUseSoftmax, - kUseCausal>::Run(q_host, + kUseCausal>::Run(is_cross_attention, + q_host, k_host, v_host, o_host_ref, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 771adccb44..ebe5c600a7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -121,13 +121,15 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream) { const auto kargs = [&] { - return HstuKernel::MakeKargs(param.q_ptr, + return HstuKernel::MakeKargs(param.is_cross_attention, + param.q_ptr, param.k_ptr, param.v_ptr, param.bias_ptr, param.o_ptr, param.seqlen_q, - param.seqlen_kv, + param.is_cross_attention ? param.seqlen_kv + : param.seqlen_q, param.hdim_qk, param.hdim_v, param.num_head, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index c03627ab67..2da36c6482 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -61,6 +61,7 @@ struct HstuAttentionFwdKernel // user need to use MakeKargs() function to create kargs. 
struct HstuAttentionFwdBatchModeBaseKargs { + bool is_cross_attention; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -99,6 +100,7 @@ struct HstuAttentionFwdKernel struct HstuAttentionFwdJaggModeBaseKargs { + bool is_cross_attention; const int32_t* seq_q_offsets_ptr; const int32_t* seq_kv_offsets_ptr; @@ -194,7 +196,8 @@ struct HstuAttentionFwdKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, + MakeKargs(bool is_cross_attention, + const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr, @@ -230,7 +233,8 @@ struct HstuAttentionFwdKernel uint64_t philox_offset) { Kargs kargs{ - {batch_stride_q, + {is_cross_attention, + batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_o, @@ -279,7 +283,8 @@ struct HstuAttentionFwdKernel template CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, + MakeKargs(bool is_cross_attention, + const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr, @@ -311,7 +316,8 @@ struct HstuAttentionFwdKernel uint64_t philox_offset) { Kargs kargs{ - {reinterpret_cast(seq_q_offsets_ptr), + {is_cross_attention, + reinterpret_cast(seq_q_offsets_ptr), reinterpret_cast(seq_kv_offsets_ptr), seq_stride_q, seq_stride_k, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 856fbec32c..897bdc9cec 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -113,13 +113,15 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream) { const auto kargs = [&] { - return HstuKernel::MakeKargs(param.q_ptr, + return HstuKernel::MakeKargs(param.is_cross_attention, + param.q_ptr, 
param.k_ptr, param.v_ptr, param.bias_ptr, param.o_ptr, param.seq_q_offsets_ptr, - param.seq_kv_offsets_ptr, + param.is_cross_attention ? param.seq_kv_offsets_ptr + : param.seq_q_offsets_ptr, param.max_seqlen, param.hdim_qk, param.hdim_v, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 68404422f5..9ce98839fd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -7,6 +7,11 @@ struct HstuAttentionFwdParams { + // for self-attention (is_cross_attention = false), we require + // 1) either seqlen_kv == 0 or seqlen_kv == seqlen_q + // 2) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr + bool is_cross_attention; + bool is_jagged; ck_tile::index_t num_batch; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index ee844f2bcb..91c456aec6 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -33,7 +33,8 @@ template struct reference_hstu_attention { - static void Run(const HostTensor& q_batch_seq_nhead_hdim, + static void Run(bool is_cross_attention, + const HostTensor& q_batch_seq_nhead_hdim, const HostTensor& k_batch_seq_nhead_hdim, const HostTensor& v_batch_seq_nhead_hdim, HostTensor& o_batch_seq_nhead_hdim, @@ -53,6 +54,8 @@ struct reference_hstu_attention int min_full_attn_seqlen) // define masking length at the end of query token // sequence which is included for full attention { + ignore = is_cross_attention; + if constexpr(kIsJagged) { // check the number of batches