mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Re-org the kernel parameters in HstuAttentionFwdBatchModeBaseKargs and HstuAttentionFwdJaggModeBaseKargs
This commit is contained in:
@@ -62,34 +62,80 @@ struct HstuAttentionFwdKernel
|
||||
// kargs use aggregate initializer, so no constructor will provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct HstuAttentionFwdCommonKargs
|
||||
struct HstuAttentionFwdBatchModeBaseKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen;
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
ck_tile::index_t seqlen;
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
ck_tile::index_t max_seqlen;
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeBaseKargs
|
||||
{
|
||||
const int32_t* seq_offsets_ptr;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t seqlen;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
ck_tile::index_t max_seqlen;
|
||||
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
{
|
||||
const void* bias_ptr = nullptr;
|
||||
@@ -102,12 +148,6 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdDropoutSeedOffset
|
||||
{
|
||||
uint64_t drop_seed;
|
||||
@@ -131,38 +171,30 @@ struct HstuAttentionFwdKernel
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdBatchModeBaseKargs,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t max_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdJaggModeBaseKargs,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seq_offsets_ptr;
|
||||
ck_tile::index_t max_seqlen;
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
@@ -202,34 +234,41 @@ struct HstuAttentionFwdKernel
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
-scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o,
|
||||
seqlen}; // max_seqlen
|
||||
Kargs kargs{
|
||||
{batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
num_head,
|
||||
-scale_s,
|
||||
seqlen, // max_seqlen
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
@@ -237,11 +276,6 @@ struct HstuAttentionFwdKernel
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
kargs.batch_stride_bias = batch_stride_bias;
|
||||
}
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
@@ -350,42 +384,44 @@ struct HstuAttentionFwdKernel
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
-scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(seq_offsets_ptr),
|
||||
max_seqlen};
|
||||
Kargs kargs{
|
||||
{reinterpret_cast<const int32_t*>(seq_offsets_ptr),
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr),
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
num_head,
|
||||
-scale_s,
|
||||
max_seqlen,
|
||||
contextual_seqlen}, // args for common karg
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
};
|
||||
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.seq_stride_bias = seq_stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
|
||||
Reference in New Issue
Block a user