Add is_cross_attention as both a host API parameter and a kernel parameter so that separate masking rules can be used for self-attention and cross-attention

This commit is contained in:
Qianfeng Zhang
2026-02-06 15:40:07 +00:00
parent d169ed2194
commit 0711f4f90a
6 changed files with 70 additions and 42 deletions

View File

@@ -263,6 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
int max_target = 0;
bool is_cross_attention = false;
if(!num_targets.empty())
{
// supplement num_targets using the last input value if user-provided lengths not enough
@@ -275,9 +277,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!");
// assume seq_lengths_kv is same as seq_lengths_q if not defined
// assume seq_lengths_kv is the same as seq_lengths_q if not defined; otherwise, when
// seq_lengths_kv is explicitly defined, treat the input as a cross-attention case
if(seq_lengths_kv.empty())
seq_lengths_kv = seq_lengths_q;
else
is_cross_attention = true;
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
if(input_max_uih_seqlen_kv <= 0)
@@ -285,10 +290,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
// supplement seq_lengths_q using the last input value if user-provided lengths not
// enough
supplement_array_by_last_element(seq_lengths_q, num_batch);
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
// supplement seq_lengths_kv using the last input value if user-provided lengths not
// enough
supplement_array_by_last_element(seq_lengths_kv, num_batch);
// only consider num_batch values even if more values are provided by the user
@@ -452,6 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
params.is_cross_attention = is_cross_attention;
params.is_jagged = true;
params.num_batch = num_batch;
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
@@ -489,35 +497,36 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
{
params.is_jagged = false;
params.num_batch = num_batch;
params.seqlen_q = phy_seqlen_q;
params.seqlen_kv = phy_seqlen_kv;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.batch_stride_q = q_host.get_strides()[0];
params.batch_stride_k = k_host.get_strides()[0];
params.batch_stride_v = v_host.get_strides()[0];
params.batch_stride_bias = 0;
params.batch_stride_o = o_host_ref.get_strides()[0];
params.is_cross_attention = is_cross_attention;
params.is_jagged = false;
params.num_batch = num_batch;
params.seqlen_q = phy_seqlen_q;
params.seqlen_kv = phy_seqlen_kv;
params.q_ptr = q_dev.GetDeviceBuffer();
params.k_ptr = k_dev.GetDeviceBuffer();
params.v_ptr = v_dev.GetDeviceBuffer();
params.bias_ptr = nullptr; // bias is not supported at present
params.o_ptr = o_dev.GetDeviceBuffer();
params.hdim_qk = hdim_qk;
params.hdim_v = hdim_v;
params.num_head = num_head;
params.scale_s = scale_s;
params.attn_scale = attn_scale;
params.seq_stride_q = q_host.get_strides()[1];
params.seq_stride_k = k_host.get_strides()[1];
params.seq_stride_v = v_host.get_strides()[1];
params.seq_stride_bias = 0;
params.seq_stride_o = o_host_ref.get_strides()[1];
params.nhead_stride_q = q_host.get_strides()[2];
params.nhead_stride_k = k_host.get_strides()[2];
params.nhead_stride_v = v_host.get_strides()[2];
params.nhead_stride_bias = 0;
params.nhead_stride_o = o_host_ref.get_strides()[2];
params.batch_stride_q = q_host.get_strides()[0];
params.batch_stride_k = k_host.get_strides()[0];
params.batch_stride_v = v_host.get_strides()[0];
params.batch_stride_bias = 0;
params.batch_stride_o = o_host_ref.get_strides()[0];
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
params.use_softmax = use_softmax;
params.use_causal = use_causal;
@@ -566,7 +575,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
CompDataType,
kIsJagged,
kUseSoftmax,
kUseCausal>::Run(q_host,
kUseCausal>::Run(is_cross_attention,
q_host,
k_host,
v_host,
o_host_ref,

View File

@@ -121,13 +121,15 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
return HstuKernel::MakeKargs(param.is_cross_attention,
param.q_ptr,
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_ptr,
param.seqlen_q,
param.seqlen_kv,
param.is_cross_attention ? param.seqlen_kv
: param.seqlen_q,
param.hdim_qk,
param.hdim_v,
param.num_head,

View File

@@ -61,6 +61,7 @@ struct HstuAttentionFwdKernel
// user need to use MakeKargs() function to create kargs.
struct HstuAttentionFwdBatchModeBaseKargs
{
bool is_cross_attention;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -99,6 +100,7 @@ struct HstuAttentionFwdKernel
struct HstuAttentionFwdJaggModeBaseKargs
{
bool is_cross_attention;
const int32_t* seq_q_offsets_ptr;
const int32_t* seq_kv_offsets_ptr;
@@ -194,7 +196,8 @@ struct HstuAttentionFwdKernel
template <bool Cond = !kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargs(bool is_cross_attention,
const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
@@ -230,7 +233,8 @@ struct HstuAttentionFwdKernel
uint64_t philox_offset)
{
Kargs kargs{
{batch_stride_q,
{is_cross_attention,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
@@ -279,7 +283,8 @@ struct HstuAttentionFwdKernel
template <bool Cond = kIsJagged>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargs(bool is_cross_attention,
const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
@@ -311,7 +316,8 @@ struct HstuAttentionFwdKernel
uint64_t philox_offset)
{
Kargs kargs{
{reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
{is_cross_attention,
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
seq_stride_q,
seq_stride_k,

View File

@@ -113,13 +113,15 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
{
const auto kargs = [&] {
return HstuKernel::MakeKargs(param.q_ptr,
return HstuKernel::MakeKargs(param.is_cross_attention,
param.q_ptr,
param.k_ptr,
param.v_ptr,
param.bias_ptr,
param.o_ptr,
param.seq_q_offsets_ptr,
param.seq_kv_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
param.max_seqlen,
param.hdim_qk,
param.hdim_v,

View File

@@ -7,6 +7,11 @@
struct HstuAttentionFwdParams
{
// for self-attention (is_cross_attention = false), we require:
// 1) either seqlen_kv == 0 or seqlen_kv == seqlen_q
// 2) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
bool is_cross_attention;
bool is_jagged;
ck_tile::index_t num_batch;

View File

@@ -33,7 +33,8 @@ template <typename InOutDataType,
bool kUseCausal>
struct reference_hstu_attention
{
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
static void Run(bool is_cross_attention,
const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
@@ -53,6 +54,8 @@ struct reference_hstu_attention
int min_full_attn_seqlen) // define masking length at the end of query token
// sequence which is included for full attention
{
ignore = is_cross_attention;
if constexpr(kIsJagged)
{
// check the number of batches