Add parameters used by storing lse in the fwd and fwd_splitkv_combine kernel to prepare for supporting training

This commit is contained in:
Qianfeng Zhang
2026-06-03 09:19:32 +00:00
parent 5ee8a37cd3
commit eba3c2f635
10 changed files with 162 additions and 30 deletions

View File

@@ -139,6 +139,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
param.v_ptr,
param.bias_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.seqlen_q,
param.is_cross_attention ? param.seqlen_kv
: param.seqlen_q,
@@ -152,16 +153,19 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
param.seq_stride_v,
param.seq_stride_bias,
param.seq_stride_o,
0, // seq_stride_lse
param.nhead_stride_q,
param.nhead_stride_k,
param.nhead_stride_v,
param.nhead_stride_bias,
param.nhead_stride_o,
0, // nhead_stride_lse
param.batch_stride_q,
param.batch_stride_k,
param.batch_stride_v,
param.batch_stride_bias,
param.batch_stride_o,
0, // batch_stride_lse
param.num_targets_ptr,
param.contextual_seqlen,
param.window_size,

View File

@@ -334,9 +334,13 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.batch_stride_o,
0, // batch_stride_lse
param.seq_stride_o,
0, // seq_stride_o
param.nhead_stride_o,
0, // nhead_stride_o
param.seqlen_q,
param.num_head,
ws.num_splits,

View File

@@ -47,6 +47,7 @@ struct HstuAttentionFwdKernel
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax;
static constexpr bool kStoreLSE = HstuAttentionPipeline::Problem::kStoreLSE;
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
@@ -202,6 +203,21 @@ struct HstuAttentionFwdKernel
uint64_t drop_offset;
};
struct HstuAttentionFwdBatchedLSEKargs
{
void* lse_ptr;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_lse;
};
struct HstuAttentionFwdJaggedLSEKargs
{
void* lse_ptr;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_lse;
};
struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
@@ -226,7 +242,11 @@ struct HstuAttentionFwdKernel
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
HstuAttentionFwdEmptyKargs<2>>,
std::conditional_t<kStoreLSE,
HstuAttentionFwdBatchedLSEKargs,
HstuAttentionFwdEmptyKargs<3>>
{
};
@@ -237,7 +257,10 @@ struct HstuAttentionFwdKernel
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
HstuAttentionFwdEmptyKargs<2>>,
std::conditional_t<kStoreLSE,
HstuAttentionFwdJaggedLSEKargs,
HstuAttentionFwdEmptyKargs<3>>
{
};
@@ -247,7 +270,10 @@ struct HstuAttentionFwdKernel
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
HstuAttentionFwdEmptyKargs<2>>,
std::conditional_t<kStoreLSE,
HstuAttentionFwdJaggedLSEKargs,
HstuAttentionFwdEmptyKargs<3>>
{
};
@@ -267,6 +293,7 @@ struct HstuAttentionFwdKernel
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
void* lse_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_kv,
ck_tile::index_t hdim_qk,
@@ -279,16 +306,19 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t seq_stride_lse,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_lse,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
@@ -327,6 +357,7 @@ struct HstuAttentionFwdKernel
min_full_attn_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dropout
{}, // placeholder for LSE
};
if constexpr(kHasBias)
@@ -340,6 +371,13 @@ struct HstuAttentionFwdKernel
{
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.batch_stride_lse = batch_stride_lse;
kargs.seq_stride_lse = seq_stride_lse;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}
@@ -351,6 +389,7 @@ struct HstuAttentionFwdKernel
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
void* lse_ptr,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
ck_tile::index_t max_seqlen_q,
@@ -364,11 +403,13 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t seq_stride_lse,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_lse,
const void* num_targets_ptr,
ck_tile::index_t contextual_seqlen,
ck_tile::index_t window_size,
@@ -405,6 +446,7 @@ struct HstuAttentionFwdKernel
min_full_attn_seqlen}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dropout
{}, // placeholder for LSE
};
if constexpr(kHasBias)
@@ -417,6 +459,12 @@ struct HstuAttentionFwdKernel
{
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.seq_stride_lse = seq_stride_lse;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}
@@ -428,6 +476,7 @@ struct HstuAttentionFwdKernel
const void* v_ptr,
const void* bias_ptr,
void* o_ptr,
void* lse_ptr,
ck_tile::index_t num_batch_per_group,
const void* seq_q_offsets_ptr,
const void* seq_kv_offsets_ptr,
@@ -445,11 +494,13 @@ struct HstuAttentionFwdKernel
ck_tile::index_t seq_stride_v,
ck_tile::index_t seq_stride_bias,
ck_tile::index_t seq_stride_o,
ck_tile::index_t seq_stride_lse,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_lse,
const void* num_targets_ptr,
float p_drop,
uint64_t philox_seed,
@@ -489,6 +540,7 @@ struct HstuAttentionFwdKernel
reinterpret_cast<const float*>(group_attn_scale_ptr)}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dropout
{}, // placeholder for LSE
};
if constexpr(kHasBias)
@@ -501,6 +553,12 @@ struct HstuAttentionFwdKernel
{
kargs.init_dropout(p_drop, philox_seed, philox_offset);
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.seq_stride_lse = seq_stride_lse;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}

View File

@@ -43,6 +43,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged;
static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax;
static constexpr bool kStoreLSE = HstuAttentionPipeline::Problem::kStoreLSE;
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO;
@@ -93,17 +94,40 @@ struct HstuAttentionFwdSplitKVCombineKernel
const void* lse_acc_ptr = nullptr;
};
struct HstuAttentionBatchedCombineKargs : HstuAttentionBatchedCombineBaseKargs,
std::conditional_t<kUseSoftmax,
HstuAttentionCombineSoftmaxKargs,
HstuAttentionCombineEmptyKargs<1>>
struct HstuAttentionBatchedCombineLSEKargs
{
void* lse_ptr;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_lse;
};
struct HstuAttentionJaggedCombineLSEKargs
{
void* lse_ptr;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_lse;
};
struct HstuAttentionBatchedCombineKargs
: HstuAttentionBatchedCombineBaseKargs,
std::conditional_t<kUseSoftmax,
HstuAttentionCombineSoftmaxKargs,
HstuAttentionCombineEmptyKargs<1>>,
std::conditional_t<kStoreLSE,
HstuAttentionBatchedCombineLSEKargs,
HstuAttentionCombineEmptyKargs<2>>
{
};
struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs,
std::conditional_t<kUseSoftmax,
HstuAttentionCombineSoftmaxKargs,
HstuAttentionCombineEmptyKargs<1>>
HstuAttentionCombineEmptyKargs<1>>,
std::conditional_t<kStoreLSE,
HstuAttentionJaggedCombineLSEKargs,
HstuAttentionCombineEmptyKargs<2>>
{
};
@@ -115,29 +139,43 @@ struct HstuAttentionFwdSplitKVCombineKernel
MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
const void* lse_acc_ptr, // workspace for accummulation of lse
void* o_ptr,
void* lse_ptr,
ck_tile::index_t batch_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t seq_stride_o,
ck_tile::index_t seq_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t seqlen_q,
ck_tile::index_t num_head,
ck_tile::index_t num_splits, // number of splitted seqlen_kv
ck_tile::index_t hdim_v)
{
Kargs kargs{{o_acc_ptr,
o_ptr,
batch_stride_o,
seq_stride_o,
nhead_stride_o,
seqlen_q,
num_head,
num_splits,
hdim_v},
{} /* place holder for softmax */};
Kargs kargs{
{o_acc_ptr,
o_ptr,
batch_stride_o,
seq_stride_o,
nhead_stride_o,
seqlen_q,
num_head,
num_splits,
hdim_v},
{}, // place holder for softmax
{}, // place holder for LSE
};
if constexpr(kUseSoftmax)
{
kargs.lse_acc_ptr = lse_acc_ptr;
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.batch_stride_lse = batch_stride_lse;
kargs.seq_stride_lse = seq_stride_lse;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}
@@ -147,8 +185,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o
const void* lse_acc_ptr, // workspace for accummulation of lse
void* o_ptr,
void* lse_ptr,
ck_tile::index_t seq_stride_o,
ck_tile::index_t seq_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_lse,
const void* seq_q_offsets_ptr,
ck_tile::index_t num_head,
ck_tile::index_t num_splits, // number of splitted seqlen_kv
@@ -164,13 +205,20 @@ struct HstuAttentionFwdSplitKVCombineKernel
num_splits,
hdim_v,
0 /* seqlen_q will be updated later*/},
{} /* place holder for softmax */
{}, // place holder for softmax
{}, // place holder for LSE
};
if constexpr(kUseSoftmax)
{
kargs.lse_acc_ptr = lse_acc_ptr;
}
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.seq_stride_lse = seq_stride_lse;
kargs.nhead_stride_lse = nhead_stride_lse;
}
return kargs;
}

View File

@@ -129,6 +129,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
param.v_ptr,
param.bias_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.num_batch / param.num_group,
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
@@ -147,11 +148,13 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
param.seq_stride_v,
param.seq_stride_bias,
param.seq_stride_o,
0, // seq_stride_lse
param.nhead_stride_q,
param.nhead_stride_k,
param.nhead_stride_v,
param.nhead_stride_bias,
param.nhead_stride_o,
0, // nhead_stride_lse
param.num_targets_ptr,
param.p_drop,
param.philox_seed,

View File

@@ -320,8 +320,11 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.seq_stride_o,
0, // seq_stride_lse
param.nhead_stride_o,
0, // nhead_stride_lse
param.seq_q_offsets_ptr,
param.num_head,
ws.num_splits,

View File

@@ -129,6 +129,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
param.v_ptr,
param.bias_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.seq_q_offsets_ptr,
param.is_cross_attention ? param.seq_kv_offsets_ptr
: param.seq_q_offsets_ptr,
@@ -143,11 +144,13 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
param.seq_stride_v,
param.seq_stride_bias,
param.seq_stride_o,
0, // seq_stride_o
param.nhead_stride_q,
param.nhead_stride_k,
param.nhead_stride_v,
param.nhead_stride_bias,
param.nhead_stride_o,
0, // nhead_stride_o
param.num_targets_ptr,
param.contextual_seqlen,
param.window_size,

View File

@@ -323,8 +323,11 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
return HstuKernel::MakeKargs(ws.o_acc_ptr,
ws.lse_acc_ptr,
param.o_ptr,
nullptr, // lse_ptr
param.seq_stride_o,
0, // seq_stride_lse
param.nhead_stride_o,
0, // nhead_stride_lse
param.seq_q_offsets_ptr,
param.num_head,
ws.num_splits,

View File

@@ -14,6 +14,10 @@ struct HstuAttentionNoGroupFwdParams
bool is_jagged;
bool use_softmax;
bool is_training;
ck_tile::index_t num_batch;
ck_tile::index_t seqlen_q; // batched mode only
ck_tile::index_t seqlen_kv; // batched mode only
@@ -26,6 +30,7 @@ struct HstuAttentionNoGroupFwdParams
const void* v_ptr;
const void* bias_ptr;
void* o_ptr;
void* lse_ptr; // only used when both is_training and use_softmax be true
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
@@ -38,12 +43,14 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t seq_stride_o;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_lse;
// batched mode only parameters
ck_tile::index_t batch_stride_q;
@@ -51,6 +58,7 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse;
const void* num_targets_ptr;
@@ -60,8 +68,6 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t contextual_seqlen;
ck_tile::index_t min_full_attn_seqlen;
bool use_softmax;
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;
@@ -73,6 +79,10 @@ struct HstuAttentionGroupFwdParams
// 1) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr
bool is_cross_attention;
bool use_softmax;
bool is_training;
ck_tile::index_t num_group;
ck_tile::index_t num_batch;
const void* seq_q_offsets_ptr;
@@ -84,6 +94,7 @@ struct HstuAttentionGroupFwdParams
const void* v_ptr;
const void* bias_ptr;
void* o_ptr;
void* lse_ptr; // only used when both is_training and use_softmax be true
ck_tile::index_t hdim_qk;
ck_tile::index_t hdim_v;
@@ -95,19 +106,14 @@ struct HstuAttentionGroupFwdParams
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t seq_stride_o;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
// batched mode only parameters
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t nhead_stride_lse;
const void* num_targets_ptr;
@@ -120,8 +126,6 @@ struct HstuAttentionGroupFwdParams
const void* group_contextual_seqlen_ptr;
const void* group_min_full_attn_seqlen_ptr;
bool use_softmax;
float p_drop;
uint64_t philox_seed;
uint64_t philox_offset;

View File

@@ -96,6 +96,8 @@ struct HstuAttentionFwdPipelineProblem
static_assert(!kUseGroup || (kUseGroup && kIsJagged),
"Group HSTU is only used with jagged mode!");
static_assert(!kStoreLSE || (kStoreLSE && kUseSoftmax),
"Storing Lse is only necessary when softmax is used!");
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;