Tiny simplification with defining the Bias related Kargs

This commit is contained in:
Qianfeng Zhang
2026-06-03 09:43:44 +00:00
parent f41b0176d3
commit ee5bd0ebba
2 changed files with 24 additions and 18 deletions

View File

@@ -185,16 +185,19 @@ struct HstuAttentionFwdKernel
const float* group_attn_scale_ptr;
};
struct HstuAttentionFwdCommonBiasKargs
struct HstuAttentionFwdBatchedBiasKargs
{
const void* bias_ptr = nullptr;
ck_tile::index_t seq_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
const void* bias_ptr;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t batch_stride_bias;
};
struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs
struct HstuAttentionFwdJaggedBiasKargs
{
ck_tile::index_t batch_stride_bias = 0;
const void* bias_ptr;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t nhead_stride_bias;
};
struct HstuAttentionFwdDropoutSeedOffset
@@ -238,7 +241,7 @@ struct HstuAttentionFwdKernel
struct HstuAttentionNoGroupBatchedFwdKargs
: HstuAttentionNoGroupBatchedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdBatchedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
@@ -253,7 +256,7 @@ struct HstuAttentionFwdKernel
struct HstuAttentionNoGroupJaggedFwdKargs
: HstuAttentionNoGroupJaggedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdJaggedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
@@ -266,7 +269,7 @@ struct HstuAttentionFwdKernel
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdJaggedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,

View File

@@ -185,16 +185,19 @@ struct HstuAttentionFwdSplitKVKernel
const float* group_attn_scale_ptr;
};
struct HstuAttentionFwdCommonBiasKargs
struct HstuAttentionFwdBatchedBiasKargs
{
const void* bias_ptr = nullptr;
ck_tile::index_t seq_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
const void* bias_ptr;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t batch_stride_bias;
};
struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs
struct HstuAttentionFwdJaggedBiasKargs
{
ck_tile::index_t batch_stride_bias = 0;
const void* bias_ptr;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t nhead_stride_bias;
};
struct HstuAttentionFwdDropoutSeedOffset
@@ -228,7 +231,7 @@ struct HstuAttentionFwdSplitKVKernel
struct HstuAttentionNoGroupBatchedFwdKargs
: HstuAttentionNoGroupBatchedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdBatchedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
@@ -242,7 +245,7 @@ struct HstuAttentionFwdSplitKVKernel
struct HstuAttentionNoGroupJaggedFwdKargs
: HstuAttentionNoGroupJaggedFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdJaggedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
@@ -255,7 +258,7 @@ struct HstuAttentionFwdSplitKVKernel
struct HstuAttentionGroupFwdKargs : HstuAttentionGroupFwdBaseKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdJaggedBiasKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,