mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Add parameters used by storing lse in the fwd and fwd_splitkv_combine kernel to prepare for supporting training
This commit is contained in:
@@ -139,6 +139,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.seqlen_q,
|
||||
param.is_cross_attention ? param.seqlen_kv
|
||||
: param.seqlen_q,
|
||||
@@ -152,16 +153,19 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.batch_stride_q,
|
||||
param.batch_stride_k,
|
||||
param.batch_stride_v,
|
||||
param.batch_stride_bias,
|
||||
param.batch_stride_o,
|
||||
0, // batch_stride_lse
|
||||
param.num_targets_ptr,
|
||||
param.contextual_seqlen,
|
||||
param.window_size,
|
||||
|
||||
@@ -334,9 +334,13 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.batch_stride_o,
|
||||
0, // batch_stride_lse
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.seqlen_q,
|
||||
param.num_head,
|
||||
ws.num_splits,
|
||||
|
||||
@@ -47,6 +47,7 @@ struct HstuAttentionFwdKernel
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
|
||||
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
|
||||
static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax;
|
||||
static constexpr bool kStoreLSE = HstuAttentionPipeline::Problem::kStoreLSE;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
@@ -202,6 +203,21 @@ struct HstuAttentionFwdKernel
|
||||
uint64_t drop_offset;
|
||||
};
|
||||
|
||||
// Optional kernel arguments describing the LSE output in batched mode: the destination
// pointer plus batch / sequence / head strides. Selected as a base class through
// std::conditional_t when kStoreLSE is true (HstuAttentionFwdEmptyKargs<3> is used
// otherwise); the fields are filled in by MakeKargs under `if constexpr(kStoreLSE)`.
// NOTE(review): "LSE" is presumably the softmax log-sum-exp, and the stride units
// (elements vs. bytes) are not visible here -- confirm against the pipeline code.
struct HstuAttentionFwdBatchedLSEKargs
|
||||
{
|
||||
void* lse_ptr;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
};
|
||||
|
||||
// Optional kernel arguments describing the LSE output for the non-batched Kargs variants:
// the destination pointer plus sequence / head strides. Unlike the batched counterpart
// there is no batch stride field. Selected as a base class through std::conditional_t
// when kStoreLSE is true (HstuAttentionFwdEmptyKargs<3> is used otherwise).
// NOTE(review): "LSE" is presumably the softmax log-sum-exp, and the stride units
// (elements vs. bytes) are not visible here -- confirm against the pipeline code.
struct HstuAttentionFwdJaggedLSEKargs
|
||||
{
|
||||
void* lse_ptr;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
@@ -226,7 +242,11 @@ struct HstuAttentionFwdKernel
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
HstuAttentionFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kStoreLSE,
|
||||
HstuAttentionFwdBatchedLSEKargs,
|
||||
HstuAttentionFwdEmptyKargs<3>>
|
||||
|
||||
{
|
||||
};
|
||||
|
||||
@@ -237,7 +257,10 @@ struct HstuAttentionFwdKernel
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
HstuAttentionFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kStoreLSE,
|
||||
HstuAttentionFwdJaggedLSEKargs,
|
||||
HstuAttentionFwdEmptyKargs<3>>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -247,7 +270,10 @@ struct HstuAttentionFwdKernel
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
HstuAttentionFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kStoreLSE,
|
||||
HstuAttentionFwdJaggedLSEKargs,
|
||||
HstuAttentionFwdEmptyKargs<3>>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -267,6 +293,7 @@ struct HstuAttentionFwdKernel
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_kv,
|
||||
ck_tile::index_t hdim_qk,
|
||||
@@ -279,16 +306,19 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t seq_stride_lse,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
@@ -327,6 +357,7 @@ struct HstuAttentionFwdKernel
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for LSE
|
||||
};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
@@ -340,6 +371,13 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
kargs.seq_stride_lse = seq_stride_lse;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -351,6 +389,7 @@ struct HstuAttentionFwdKernel
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
@@ -364,11 +403,13 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t seq_stride_lse,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t window_size,
|
||||
@@ -405,6 +446,7 @@ struct HstuAttentionFwdKernel
|
||||
min_full_attn_seqlen}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for LSE
|
||||
};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
@@ -417,6 +459,12 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.seq_stride_lse = seq_stride_lse;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -428,6 +476,7 @@ struct HstuAttentionFwdKernel
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
ck_tile::index_t num_batch_per_group,
|
||||
const void* seq_q_offsets_ptr,
|
||||
const void* seq_kv_offsets_ptr,
|
||||
@@ -445,11 +494,13 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t seq_stride_lse,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
const void* num_targets_ptr,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
@@ -489,6 +540,7 @@ struct HstuAttentionFwdKernel
|
||||
reinterpret_cast<const float*>(group_attn_scale_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for LSE
|
||||
};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
@@ -501,6 +553,12 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
kargs.init_dropout(p_drop, philox_seed, philox_offset);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.seq_stride_lse = seq_stride_lse;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged;
|
||||
static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax;
|
||||
static constexpr bool kStoreLSE = HstuAttentionPipeline::Problem::kStoreLSE;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO;
|
||||
@@ -93,17 +94,40 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
const void* lse_acc_ptr = nullptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionBatchedCombineKargs : HstuAttentionBatchedCombineBaseKargs,
|
||||
std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionCombineSoftmaxKargs,
|
||||
HstuAttentionCombineEmptyKargs<1>>
|
||||
// Optional arguments of the split-KV combine kernel describing the final LSE output in
// batched mode: the destination pointer plus batch / sequence / head strides. Inherited by
// HstuAttentionBatchedCombineKargs only when kStoreLSE is true
// (HstuAttentionCombineEmptyKargs<2> is used otherwise); MakeKargs fills the fields under
// `if constexpr(kStoreLSE)`.
// NOTE(review): stride units (elements vs. bytes) are not visible here -- confirm.
struct HstuAttentionBatchedCombineLSEKargs
|
||||
{
|
||||
void* lse_ptr;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
};
|
||||
|
||||
// Optional arguments of the split-KV combine kernel describing the final LSE output in
// jagged mode: the destination pointer plus sequence / head strides (no batch stride).
// Inherited by HstuAttentionJaggedCombineKargs only when kStoreLSE is true
// (HstuAttentionCombineEmptyKargs<2> is used otherwise); MakeKargs fills the fields under
// `if constexpr(kStoreLSE)`.
// NOTE(review): stride units (elements vs. bytes) are not visible here -- confirm.
struct HstuAttentionJaggedCombineLSEKargs
|
||||
{
|
||||
void* lse_ptr;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
};
|
||||
|
||||
struct HstuAttentionBatchedCombineKargs
|
||||
: HstuAttentionBatchedCombineBaseKargs,
|
||||
std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionCombineSoftmaxKargs,
|
||||
HstuAttentionCombineEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE,
|
||||
HstuAttentionBatchedCombineLSEKargs,
|
||||
HstuAttentionCombineEmptyKargs<2>>
|
||||
|
||||
{
|
||||
};
|
||||
|
||||
struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs,
|
||||
std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionCombineSoftmaxKargs,
|
||||
HstuAttentionCombineEmptyKargs<1>>
|
||||
HstuAttentionCombineEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE,
|
||||
HstuAttentionJaggedCombineLSEKargs,
|
||||
HstuAttentionCombineEmptyKargs<2>>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -115,29 +139,43 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
|
||||
const void* lse_acc_ptr, // workspace for accumulation of lse
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t seq_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t num_head,
|
||||
ck_tile::index_t num_splits, // number of splits of seqlen_kv
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
Kargs kargs{{o_acc_ptr,
|
||||
o_ptr,
|
||||
batch_stride_o,
|
||||
seq_stride_o,
|
||||
nhead_stride_o,
|
||||
seqlen_q,
|
||||
num_head,
|
||||
num_splits,
|
||||
hdim_v},
|
||||
{} /* place holder for softmax */};
|
||||
Kargs kargs{
|
||||
{o_acc_ptr,
|
||||
o_ptr,
|
||||
batch_stride_o,
|
||||
seq_stride_o,
|
||||
nhead_stride_o,
|
||||
seqlen_q,
|
||||
num_head,
|
||||
num_splits,
|
||||
hdim_v},
|
||||
{}, // placeholder for softmax
|
||||
{}, // placeholder for LSE
|
||||
};
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
kargs.lse_acc_ptr = lse_acc_ptr;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.batch_stride_lse = batch_stride_lse;
|
||||
kargs.seq_stride_lse = seq_stride_lse;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -147,8 +185,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
|
||||
const void* lse_acc_ptr, // workspace for accumulation of lse
|
||||
void* o_ptr,
|
||||
void* lse_ptr,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t seq_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
const void* seq_q_offsets_ptr,
|
||||
ck_tile::index_t num_head,
|
||||
ck_tile::index_t num_splits, // number of splits of seqlen_kv
|
||||
@@ -164,13 +205,20 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
num_splits,
|
||||
hdim_v,
|
||||
0 /* seqlen_q will be updated later*/},
|
||||
{} /* place holder for softmax */
|
||||
{}, // placeholder for softmax
|
||||
{}, // placeholder for LSE
|
||||
};
|
||||
|
||||
if constexpr(kUseSoftmax)
|
||||
{
|
||||
kargs.lse_acc_ptr = lse_acc_ptr;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
kargs.lse_ptr = lse_ptr;
|
||||
kargs.seq_stride_lse = seq_stride_lse;
|
||||
kargs.nhead_stride_lse = nhead_stride_lse;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -129,6 +129,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.num_batch / param.num_group,
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
@@ -147,11 +148,13 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.num_targets_ptr,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
|
||||
@@ -320,8 +320,11 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.seq_q_offsets_ptr,
|
||||
param.num_head,
|
||||
ws.num_splits,
|
||||
|
||||
@@ -129,6 +129,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.seq_q_offsets_ptr,
|
||||
param.is_cross_attention ? param.seq_kv_offsets_ptr
|
||||
: param.seq_q_offsets_ptr,
|
||||
@@ -143,11 +144,13 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.num_targets_ptr,
|
||||
param.contextual_seqlen,
|
||||
param.window_size,
|
||||
|
||||
@@ -323,8 +323,11 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
return HstuKernel::MakeKargs(ws.o_acc_ptr,
|
||||
ws.lse_acc_ptr,
|
||||
param.o_ptr,
|
||||
nullptr, // lse_ptr
|
||||
param.seq_stride_o,
|
||||
0, // seq_stride_lse
|
||||
param.nhead_stride_o,
|
||||
0, // nhead_stride_lse
|
||||
param.seq_q_offsets_ptr,
|
||||
param.num_head,
|
||||
ws.num_splits,
|
||||
|
||||
@@ -14,6 +14,10 @@ struct HstuAttentionNoGroupFwdParams
|
||||
|
||||
bool is_jagged;
|
||||
|
||||
bool use_softmax;
|
||||
|
||||
bool is_training;
|
||||
|
||||
ck_tile::index_t num_batch;
|
||||
ck_tile::index_t seqlen_q; // batched mode only
|
||||
ck_tile::index_t seqlen_kv; // batched mode only
|
||||
@@ -26,6 +30,7 @@ struct HstuAttentionNoGroupFwdParams
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
void* o_ptr;
|
||||
void* lse_ptr; // only used when both is_training and use_softmax are true
|
||||
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
@@ -38,12 +43,14 @@ struct HstuAttentionNoGroupFwdParams
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
|
||||
// batched mode only parameters
|
||||
ck_tile::index_t batch_stride_q;
|
||||
@@ -51,6 +58,7 @@ struct HstuAttentionNoGroupFwdParams
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
|
||||
const void* num_targets_ptr;
|
||||
|
||||
@@ -60,8 +68,6 @@ struct HstuAttentionNoGroupFwdParams
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
|
||||
bool use_softmax;
|
||||
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
@@ -73,6 +79,10 @@ struct HstuAttentionGroupFwdParams
|
||||
// 1) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
|
||||
bool is_cross_attention;
|
||||
|
||||
bool use_softmax;
|
||||
|
||||
bool is_training;
|
||||
|
||||
ck_tile::index_t num_group;
|
||||
ck_tile::index_t num_batch;
|
||||
const void* seq_q_offsets_ptr;
|
||||
@@ -84,6 +94,7 @@ struct HstuAttentionGroupFwdParams
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
void* o_ptr;
|
||||
void* lse_ptr; // only used when both is_training and use_softmax are true
|
||||
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
@@ -95,19 +106,14 @@ struct HstuAttentionGroupFwdParams
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
ck_tile::index_t seq_stride_lse;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
// batched mode only parameters
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
|
||||
const void* num_targets_ptr;
|
||||
|
||||
@@ -120,8 +126,6 @@ struct HstuAttentionGroupFwdParams
|
||||
const void* group_contextual_seqlen_ptr;
|
||||
const void* group_min_full_attn_seqlen_ptr;
|
||||
|
||||
bool use_softmax;
|
||||
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
|
||||
@@ -96,6 +96,8 @@ struct HstuAttentionFwdPipelineProblem
|
||||
|
||||
static_assert(!kUseGroup || (kUseGroup && kIsJagged),
|
||||
"Group HSTU is only used with jagged mode!");
|
||||
static_assert(!kStoreLSE || (kStoreLSE && kUseSoftmax),
|
||||
"Storing Lse is only necessary when softmax is used!");
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user