mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Add is_cross_attention as both host API and kernel parameter so that separate masking rules are used for self or cross attention
This commit is contained in:
@@ -263,6 +263,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int max_target = 0;
|
||||
|
||||
bool is_cross_attention = false;
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
// supplement num_targets using the last input value if user-provided lengths not enough
|
||||
@@ -275,9 +277,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
HSTU_CHECK(!seq_lengths_q.empty(), "sequence lengths of q shoud be defined!");
|
||||
|
||||
// assume seq_lengths_kv is same as seq_lengths_q if not defined
|
||||
// assume seq_lengths_kv is same as seq_lengths_q if not defined, or else when
|
||||
// seq_lengths_kv is explicitly defined, we think the input case is a cross_attention case
|
||||
if(seq_lengths_kv.empty())
|
||||
seq_lengths_kv = seq_lengths_q;
|
||||
else
|
||||
is_cross_attention = true;
|
||||
|
||||
// assume input_max_uih_seqlen_kv is same as input_max_uih_seqlen_q if not strictly defined
|
||||
if(input_max_uih_seqlen_kv <= 0)
|
||||
@@ -285,10 +290,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not enough
|
||||
// supplement seq_lengths_q using the last input value if user-provided lengths not
|
||||
// enough
|
||||
supplement_array_by_last_element(seq_lengths_q, num_batch);
|
||||
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not enough
|
||||
// supplement seq_lengths_kv using the last input value if user-provided lengths not
|
||||
// enough
|
||||
supplement_array_by_last_element(seq_lengths_kv, num_batch);
|
||||
|
||||
// only consider num_batch values even if more values are provided by the user
|
||||
@@ -452,6 +459,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
params.is_cross_attention = is_cross_attention;
|
||||
params.is_jagged = true;
|
||||
params.num_batch = num_batch;
|
||||
params.seq_q_offsets_ptr = seq_offsets_q_dev.GetDeviceBuffer();
|
||||
@@ -489,35 +497,36 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else
|
||||
{
|
||||
params.is_jagged = false;
|
||||
params.num_batch = num_batch;
|
||||
params.seqlen_q = phy_seqlen_q;
|
||||
params.seqlen_kv = phy_seqlen_kv;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.batch_stride_q = q_host.get_strides()[0];
|
||||
params.batch_stride_k = k_host.get_strides()[0];
|
||||
params.batch_stride_v = v_host.get_strides()[0];
|
||||
params.batch_stride_bias = 0;
|
||||
params.batch_stride_o = o_host_ref.get_strides()[0];
|
||||
params.is_cross_attention = is_cross_attention;
|
||||
params.is_jagged = false;
|
||||
params.num_batch = num_batch;
|
||||
params.seqlen_q = phy_seqlen_q;
|
||||
params.seqlen_kv = phy_seqlen_kv;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr; // bias is not supported at present
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = scale_s;
|
||||
params.attn_scale = attn_scale;
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.batch_stride_q = q_host.get_strides()[0];
|
||||
params.batch_stride_k = k_host.get_strides()[0];
|
||||
params.batch_stride_v = v_host.get_strides()[0];
|
||||
params.batch_stride_bias = 0;
|
||||
params.batch_stride_o = o_host_ref.get_strides()[0];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_softmax = use_softmax;
|
||||
params.use_causal = use_causal;
|
||||
@@ -566,7 +575,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseSoftmax,
|
||||
kUseCausal>::Run(q_host,
|
||||
kUseCausal>::Run(is_cross_attention,
|
||||
q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
|
||||
@@ -121,13 +121,15 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
return HstuKernel::MakeKargs(param.is_cross_attention,
|
||||
param.q_ptr,
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seqlen_q,
|
||||
param.seqlen_kv,
|
||||
param.is_cross_attention ? param.seqlen_kv
|
||||
: param.seqlen_q,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
|
||||
@@ -61,6 +61,7 @@ struct HstuAttentionFwdKernel
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct HstuAttentionFwdBatchModeBaseKargs
|
||||
{
|
||||
bool is_cross_attention;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -99,6 +100,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
struct HstuAttentionFwdJaggModeBaseKargs
|
||||
{
|
||||
bool is_cross_attention;
|
||||
const int32_t* seq_q_offsets_ptr;
|
||||
const int32_t* seq_kv_offsets_ptr;
|
||||
|
||||
@@ -194,7 +196,8 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
MakeKargs(bool is_cross_attention,
|
||||
const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
@@ -230,7 +233,8 @@ struct HstuAttentionFwdKernel
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{batch_stride_q,
|
||||
{is_cross_attention,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o,
|
||||
@@ -279,7 +283,8 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
MakeKargs(bool is_cross_attention,
|
||||
const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
@@ -311,7 +316,8 @@ struct HstuAttentionFwdKernel
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
Kargs kargs{
|
||||
{reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
|
||||
{is_cross_attention,
|
||||
reinterpret_cast<const int32_t*>(seq_q_offsets_ptr),
|
||||
reinterpret_cast<const int32_t*>(seq_kv_offsets_ptr),
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
|
||||
@@ -113,13 +113,15 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
return HstuKernel::MakeKargs(param.is_cross_attention,
|
||||
param.q_ptr,
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.seq_kv_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
param.max_seqlen,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
|
||||
struct HstuAttentionFwdParams
|
||||
{
|
||||
// for self-attention (is_cross_attention = false), we requires
|
||||
// 1) either seqlen_kv == 0 or seqlen_kv == seqlen_q
|
||||
// 2) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
|
||||
bool is_cross_attention;
|
||||
|
||||
bool is_jagged;
|
||||
|
||||
ck_tile::index_t num_batch;
|
||||
|
||||
@@ -33,7 +33,8 @@ template <typename InOutDataType,
|
||||
bool kUseCausal>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
static void Run(bool is_cross_attention,
|
||||
const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
|
||||
HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
|
||||
@@ -53,6 +54,8 @@ struct reference_hstu_attention
|
||||
int min_full_attn_seqlen) // define masking length at the end of query token
|
||||
// sequence which is included for full attention
|
||||
{
|
||||
ignore = is_cross_attention;
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// check the number of batches
|
||||
|
||||
Reference in New Issue
Block a user