mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Tiny simplification with defining the Bias related Kargs
This commit is contained in:
@@ -185,16 +185,19 @@ struct HstuAttentionFwdKernel
|
||||
const float* group_attn_scale_ptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
struct HstuAttentionFwdBatchedBiasKargs
|
||||
{
|
||||
const void* bias_ptr = nullptr;
|
||||
ck_tile::index_t seq_stride_bias = 0;
|
||||
ck_tile::index_t nhead_stride_bias = 0;
|
||||
const void* bias_ptr;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs
|
||||
struct HstuAttentionFwdJaggedBiasKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
const void* bias_ptr;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdDropoutSeedOffset
|
||||
@@ -238,7 +241,7 @@ struct HstuAttentionFwdKernel
|
||||
struct HstuAttentionNoGroupBatchedFwdKargs
|
||||
: HstuAttentionNoGroupBatchedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdBatchedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
@@ -253,7 +256,7 @@ struct HstuAttentionFwdKernel
|
||||
struct HstuAttentionNoGroupJaggedFwdKargs
|
||||
: HstuAttentionNoGroupJaggedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdJaggedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
@@ -266,7 +269,7 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdJaggedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
|
||||
@@ -185,16 +185,19 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
const float* group_attn_scale_ptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
struct HstuAttentionFwdBatchedBiasKargs
|
||||
{
|
||||
const void* bias_ptr = nullptr;
|
||||
ck_tile::index_t seq_stride_bias = 0;
|
||||
ck_tile::index_t nhead_stride_bias = 0;
|
||||
const void* bias_ptr;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs
|
||||
struct HstuAttentionFwdJaggedBiasKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
const void* bias_ptr;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdDropoutSeedOffset
|
||||
@@ -228,7 +231,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
struct HstuAttentionNoGroupBatchedFwdKargs
|
||||
: HstuAttentionNoGroupBatchedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdBatchedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
@@ -242,7 +245,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
struct HstuAttentionNoGroupJaggedFwdKargs
|
||||
: HstuAttentionNoGroupJaggedFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdJaggedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
@@ -255,7 +258,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
|
||||
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdJaggedBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
|
||||
Reference in New Issue
Block a user