Re-org the kernel parameters in HstuAttentionFwdBatchModeBaseKargs and HstuAttentionFwdJaggModeBaseKargs

This commit is contained in:
Qianfeng Zhang
2025-07-17 04:48:55 +00:00
parent fdd9c117d4
commit 0306a1e6ab

View File

@@ -62,34 +62,80 @@ struct HstuAttentionFwdKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct HstuAttentionFwdCommonKargs
struct HstuAttentionFwdBatchModeBaseKargs
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
const int32_t* num_targets_ptr;
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
ck_tile::index_t seqlen;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head;
float scale_s;
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
const int32_t* num_targets_ptr;
ck_tile::index_t seqlen;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_o;
ck_tile::index_t num_head;
float scale_s;
ck_tile::index_t max_seqlen;
ck_tile::index_t contextual_seqlen;
};
struct HstuAttentionFwdJaggModeBaseKargs
{
const int32_t* seq_offsets_ptr;
ck_tile::index_t seq_stride_q;
ck_tile::index_t seq_stride_k;
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_o;
const int32_t* num_targets_ptr;
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
ck_tile::index_t seqlen;
ck_tile::index_t num_head;
float scale_s;
ck_tile::index_t max_seqlen;
ck_tile::index_t contextual_seqlen;
};
struct HstuAttentionFwdMaskKargs
{
ck_tile::index_t window_size;
ck_tile::index_t min_full_attn_seqlen;
};
struct HstuAttentionFwdCommonBiasKargs
{
const void* bias_ptr = nullptr;
@@ -102,12 +148,6 @@ struct HstuAttentionFwdKernel
ck_tile::index_t batch_stride_bias = 0;
};
struct HstuAttentionFwdMaskKargs
{
ck_tile::index_t window_size;
ck_tile::index_t min_full_attn_seqlen;
};
struct HstuAttentionFwdDropoutSeedOffset
{
uint64_t drop_seed;
@@ -131,38 +171,30 @@ struct HstuAttentionFwdKernel
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
};
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs,
std::conditional_t<kHasLocalMask,
HstuAttentionFwdMaskKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_o;
ck_tile::index_t max_seqlen;
};
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs,
std::conditional_t<kHasLocalMask,
HstuAttentionFwdMaskKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
const int32_t* seq_offsets_ptr;
ck_tile::index_t max_seqlen;
};
using Kargs = std::
@@ -202,34 +234,41 @@ struct HstuAttentionFwdKernel
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
seqlen,
hdim_qk,
hdim_v,
num_head,
-scale_s,
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr),
contextual_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
seqlen}; // max_seqlen
Kargs kargs{
{batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr),
q_ptr,
k_ptr,
v_ptr,
o_ptr,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
seqlen,
hdim_qk,
hdim_v,
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_o,
num_head,
-scale_s,
seqlen, // max_seqlen
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
};
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasBias)
{
kargs.bias_ptr = bias_ptr;
@@ -237,11 +276,6 @@ struct HstuAttentionFwdKernel
kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias;
}
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasDropout)
{
auto seed = std::get<0>(drop_seed_offset);
@@ -350,42 +384,44 @@ struct HstuAttentionFwdKernel
float p_drop,
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
-1, // seqlen will be updated by another pointer
hdim_qk,
hdim_v,
num_head,
-scale_s,
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr),
contextual_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seq_offsets_ptr),
max_seqlen};
Kargs kargs{
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
seq_stride_q,
seq_stride_k,
seq_stride_v,
seq_stride_o,
reinterpret_cast<const int32_t*>(num_targets_ptr),
q_ptr,
k_ptr,
v_ptr,
o_ptr,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_o,
hdim_qk,
hdim_v,
-1, // seqlen will be updated by another pointer
num_head,
-scale_s,
max_seqlen,
contextual_seqlen}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for bias
{}, // placeholder for dropout
};
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasBias)
{
kargs.bias_ptr = bias_ptr;
kargs.seq_stride_bias = seq_stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
}
if constexpr(kHasDropout)
{
auto seed = std::get<0>(drop_seed_offset);