From eba3c2f635f2aa16359cec307e8666e62e09f6bc Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 3 Jun 2026 09:19:32 +0000 Subject: [PATCH] Add parameters used by storing lse in the fwd and fwd_splitkv_combine kernel to prepare for supporting training --- ...stu_attention_batched_forward_dispatch.hpp | 4 + ...ntion_batched_forward_splitkv_dispatch.hpp | 4 + .../hstu_attention_fwd_kernel.hpp | 64 ++++++++++++++- ...u_attention_fwd_splitkv_combine_kernel.hpp | 80 +++++++++++++++---- .../hstu_attention_group_forward_dispatch.hpp | 3 + ...tention_group_forward_splitkv_dispatch.hpp | 3 + ...hstu_attention_jagged_forward_dispatch.hpp | 3 + ...ention_jagged_forward_splitkv_dispatch.hpp | 3 + .../hstu_attention_params.hpp | 26 +++--- .../hstu_attention_pipeline_problem.hpp | 2 + 10 files changed, 162 insertions(+), 30 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 3426ddb328..dcb7e86689 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -139,6 +139,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch param.v_ptr, param.bias_ptr, param.o_ptr, + nullptr, // lse_ptr param.seqlen_q, param.is_cross_attention ? 
param.seqlen_kv : param.seqlen_q, @@ -152,16 +153,19 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch param.seq_stride_v, param.seq_stride_bias, param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_q, param.nhead_stride_k, param.nhead_stride_v, param.nhead_stride_bias, param.nhead_stride_o, + 0, // nhead_stride_lse param.batch_stride_q, param.batch_stride_k, param.batch_stride_v, param.batch_stride_bias, param.batch_stride_o, + 0, // batch_stride_lse param.num_targets_ptr, param.contextual_seqlen, param.window_size, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp index 4d493bf028..4e75d43a05 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -334,9 +334,13 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch return HstuKernel::MakeKargs(ws.o_acc_ptr, ws.lse_acc_ptr, param.o_ptr, + nullptr, // lse_ptr param.batch_stride_o, + 0, // batch_stride_lse param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_o, + 0, // nhead_stride_lse param.seqlen_q, param.num_head, ws.num_splits, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 5393af38c7..76013bf2e3 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -47,6 +47,7 @@ struct HstuAttentionFwdKernel static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout; static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal; static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax; + static constexpr bool kStoreLSE =
HstuAttentionPipeline::Problem::kStoreLSE; static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; @@ -202,6 +203,21 @@ struct HstuAttentionFwdKernel uint64_t drop_offset; }; + struct HstuAttentionFwdBatchedLSEKargs + { + void* lse_ptr; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t seq_stride_lse; + ck_tile::index_t nhead_stride_lse; + }; + + struct HstuAttentionFwdJaggedLSEKargs + { + void* lse_ptr; + ck_tile::index_t seq_stride_lse; + ck_tile::index_t nhead_stride_lse; + }; + struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset { void init_dropout(float p_drop, uint64_t seed, uint64_t offset) @@ -226,7 +242,11 @@ struct HstuAttentionFwdKernel HstuAttentionFwdEmptyKargs<1>>, std::conditional_t> + HstuAttentionFwdEmptyKargs<2>>, + std::conditional_t> + { }; @@ -237,7 +257,10 @@ struct HstuAttentionFwdKernel HstuAttentionFwdEmptyKargs<1>>, std::conditional_t> + HstuAttentionFwdEmptyKargs<2>>, + std::conditional_t> { }; @@ -247,7 +270,10 @@ struct HstuAttentionFwdKernel HstuAttentionFwdEmptyKargs<1>>, std::conditional_t> + HstuAttentionFwdEmptyKargs<2>>, + std::conditional_t> { }; @@ -267,6 +293,7 @@ struct HstuAttentionFwdKernel const void* v_ptr, const void* bias_ptr, void* o_ptr, + void* lse_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_kv, ck_tile::index_t hdim_qk, @@ -279,16 +306,19 @@ struct HstuAttentionFwdKernel ck_tile::index_t seq_stride_v, ck_tile::index_t seq_stride_bias, ck_tile::index_t seq_stride_o, + ck_tile::index_t seq_stride_lse, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_lse, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_o, + ck_tile::index_t 
batch_stride_lse, const void* num_targets_ptr, ck_tile::index_t contextual_seqlen, ck_tile::index_t window_size, @@ -327,6 +357,7 @@ struct HstuAttentionFwdKernel min_full_attn_seqlen}, // args for common karg {}, // placeholder for bias {}, // placeholder for dropout + {}, // placeholder for LSE }; if constexpr(kHasBias) @@ -340,6 +371,13 @@ struct HstuAttentionFwdKernel { kargs.init_dropout(p_drop, philox_seed, philox_offset); } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.batch_stride_lse = batch_stride_lse; + kargs.seq_stride_lse = seq_stride_lse; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } @@ -351,6 +389,7 @@ struct HstuAttentionFwdKernel const void* v_ptr, const void* bias_ptr, void* o_ptr, + void* lse_ptr, const void* seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, ck_tile::index_t max_seqlen_q, @@ -364,11 +403,13 @@ struct HstuAttentionFwdKernel ck_tile::index_t seq_stride_v, ck_tile::index_t seq_stride_bias, ck_tile::index_t seq_stride_o, + ck_tile::index_t seq_stride_lse, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_lse, const void* num_targets_ptr, ck_tile::index_t contextual_seqlen, ck_tile::index_t window_size, @@ -405,6 +446,7 @@ struct HstuAttentionFwdKernel min_full_attn_seqlen}, // args for common karg {}, // placeholder for bias {}, // placeholder for dropout + {}, // placeholder for LSE }; if constexpr(kHasBias) @@ -417,6 +459,12 @@ struct HstuAttentionFwdKernel { kargs.init_dropout(p_drop, philox_seed, philox_offset); } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.seq_stride_lse = seq_stride_lse; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } @@ -428,6 +476,7 @@ struct HstuAttentionFwdKernel const void* v_ptr, const void* bias_ptr, void* o_ptr, + void* lse_ptr, ck_tile::index_t num_batch_per_group, const void* 
seq_q_offsets_ptr, const void* seq_kv_offsets_ptr, @@ -445,11 +494,13 @@ struct HstuAttentionFwdKernel ck_tile::index_t seq_stride_v, ck_tile::index_t seq_stride_bias, ck_tile::index_t seq_stride_o, + ck_tile::index_t seq_stride_lse, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_lse, const void* num_targets_ptr, float p_drop, uint64_t philox_seed, @@ -489,6 +540,7 @@ struct HstuAttentionFwdKernel reinterpret_cast(group_attn_scale_ptr)}, // args for common karg {}, // placeholder for bias {}, // placeholder for dropout + {}, // placeholder for LSE }; if constexpr(kHasBias) @@ -501,6 +553,12 @@ struct HstuAttentionFwdKernel { kargs.init_dropout(p_drop, philox_seed, philox_offset); } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.seq_stride_lse = seq_stride_lse; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp index 332d51268b..fdcb1ddbc9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp @@ -43,6 +43,7 @@ struct HstuAttentionFwdSplitKVCombineKernel static constexpr bool kIsJagged = HstuAttentionPipeline::Problem::kIsJagged; static constexpr bool kUseSoftmax = HstuAttentionPipeline::Problem::kUseSoftmax; + static constexpr bool kStoreLSE = HstuAttentionPipeline::Problem::kStoreLSE; static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadHeadDimO = HstuAttentionPipeline::kPadHeadDimO; @@ -93,17 +94,40 @@ struct HstuAttentionFwdSplitKVCombineKernel const void* lse_acc_ptr = nullptr; }; - struct HstuAttentionBatchedCombineKargs : 
HstuAttentionBatchedCombineBaseKargs, - std::conditional_t> + struct HstuAttentionBatchedCombineLSEKargs + { + void* lse_ptr; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t seq_stride_lse; + ck_tile::index_t nhead_stride_lse; + }; + + struct HstuAttentionJaggedCombineLSEKargs + { + void* lse_ptr; + ck_tile::index_t seq_stride_lse; + ck_tile::index_t nhead_stride_lse; + }; + + struct HstuAttentionBatchedCombineKargs + : HstuAttentionBatchedCombineBaseKargs, + std::conditional_t>, + std::conditional_t> + { }; struct HstuAttentionJaggedCombineKargs : HstuAttentionJaggedCombineBaseKargs, std::conditional_t> + HstuAttentionCombineEmptyKargs<1>>, + std::conditional_t> { }; @@ -115,29 +139,43 @@ struct HstuAttentionFwdSplitKVCombineKernel MakeKargs(const void* o_acc_ptr, // workspace for accumulation of o const void* lse_acc_ptr, // workspace for accummulation of lse void* o_ptr, + void* lse_ptr, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_lse, ck_tile::index_t seq_stride_o, + ck_tile::index_t seq_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_lse, ck_tile::index_t seqlen_q, ck_tile::index_t num_head, ck_tile::index_t num_splits, // number of splitted seqlen_kv ck_tile::index_t hdim_v) { - Kargs kargs{{o_acc_ptr, - o_ptr, - batch_stride_o, - seq_stride_o, - nhead_stride_o, - seqlen_q, - num_head, - num_splits, - hdim_v}, - {} /* place holder for softmax */}; + Kargs kargs{ + {o_acc_ptr, + o_ptr, + batch_stride_o, + seq_stride_o, + nhead_stride_o, + seqlen_q, + num_head, + num_splits, + hdim_v}, + {}, // place holder for softmax + {}, // place holder for LSE + }; if constexpr(kUseSoftmax) { kargs.lse_acc_ptr = lse_acc_ptr; } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.batch_stride_lse = batch_stride_lse; + kargs.seq_stride_lse = seq_stride_lse; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } @@ -147,8 +185,11 @@ struct HstuAttentionFwdSplitKVCombineKernel MakeKargs(const void* 
o_acc_ptr, // workspace for accumulation of o const void* lse_acc_ptr, // workspace for accummulation of lse void* o_ptr, + void* lse_ptr, ck_tile::index_t seq_stride_o, + ck_tile::index_t seq_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_lse, const void* seq_q_offsets_ptr, ck_tile::index_t num_head, ck_tile::index_t num_splits, // number of splitted seqlen_kv @@ -164,13 +205,20 @@ struct HstuAttentionFwdSplitKVCombineKernel num_splits, hdim_v, 0 /* seqlen_q will be updated later*/}, - {} /* place holder for softmax */ + {}, // place holder for softmax + {}, // place holder for LSE }; if constexpr(kUseSoftmax) { kargs.lse_acc_ptr = lse_acc_ptr; } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.seq_stride_lse = seq_stride_lse; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp index b649ffac79..d5858c36b9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_dispatch.hpp @@ -129,6 +129,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch param.v_ptr, param.bias_ptr, param.o_ptr, + nullptr, // lse_ptr param.num_batch / param.num_group, param.seq_q_offsets_ptr, param.is_cross_attention ? 
param.seq_kv_offsets_ptr @@ -147,11 +148,13 @@ struct group_forward_causal_softmax_bias_dropout_dispatch param.seq_stride_v, param.seq_stride_bias, param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_q, param.nhead_stride_k, param.nhead_stride_v, param.nhead_stride_bias, param.nhead_stride_o, + 0, // nhead_stride_lse param.num_targets_ptr, param.p_drop, param.philox_seed, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp index 35eda44a0d..8aac606be2 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_group_forward_splitkv_dispatch.hpp @@ -320,8 +320,11 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch return HstuKernel::MakeKargs(ws.o_acc_ptr, ws.lse_acc_ptr, param.o_ptr, + nullptr, // lse_ptr param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_o, + 0, // nhead_stride_lse param.seq_q_offsets_ptr, param.num_head, ws.num_splits, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index 90d833d585..017f99358b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -129,6 +129,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch param.v_ptr, param.bias_ptr, param.o_ptr, + nullptr, // lse_ptr param.seq_q_offsets_ptr, param.is_cross_attention ? 
param.seq_kv_offsets_ptr : param.seq_q_offsets_ptr, @@ -143,11 +144,13 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch param.seq_stride_v, param.seq_stride_bias, param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_q, param.nhead_stride_k, param.nhead_stride_v, param.nhead_stride_bias, param.nhead_stride_o, + 0, // nhead_stride_lse param.num_targets_ptr, param.contextual_seqlen, param.window_size, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp index 57e5338a55..2de5ff6bc8 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_splitkv_dispatch.hpp @@ -323,8 +323,11 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch return HstuKernel::MakeKargs(ws.o_acc_ptr, ws.lse_acc_ptr, param.o_ptr, + nullptr, // lse_ptr param.seq_stride_o, + 0, // seq_stride_lse param.nhead_stride_o, + 0, // nhead_stride_lse param.seq_q_offsets_ptr, param.num_head, ws.num_splits, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index cd43e74ce5..91dcacd527 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -14,6 +14,10 @@ struct HstuAttentionNoGroupFwdParams bool is_jagged; + bool use_softmax; + + bool is_training; + ck_tile::index_t num_batch; ck_tile::index_t seqlen_q; // batched mode only ck_tile::index_t seqlen_kv; // batched mode only @@ -26,6 +30,7 @@ struct HstuAttentionNoGroupFwdParams const void* v_ptr; const void* bias_ptr; void* o_ptr; + void* lse_ptr; // only used when both is_training and use_softmax are true ck_tile::index_t hdim_qk; ck_tile::index_t hdim_v; @@ -38,12 +43,14 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t seq_stride_v; ck_tile::index_t seq_stride_bias; ck_tile::index_t seq_stride_o; + ck_tile::index_t seq_stride_lse; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_lse; // batched mode only parameters ck_tile::index_t batch_stride_q; @@ -51,6 +58,7 @@ struct HstuAttentionNoGroupFwdParams ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_lse; const void* num_targets_ptr; @@ -60,8 +68,6 @@ struct HstuAttentionNoGroupFwdParams ck_tile::index_t contextual_seqlen; ck_tile::index_t min_full_attn_seqlen; - bool use_softmax; - float p_drop; uint64_t philox_seed; uint64_t philox_offset; @@ -73,6 +79,10 @@ struct HstuAttentionGroupFwdParams // 1) either seq_kv_offsets_ptr == nullptr, or seq_kv_offsets_ptr == seq_q_offsets_ptr bool is_cross_attention; + bool use_softmax; + + bool is_training; + ck_tile::index_t num_group; ck_tile::index_t num_batch; const void* seq_q_offsets_ptr; @@ -84,6 +94,7 @@ struct HstuAttentionGroupFwdParams const void* v_ptr; const void* bias_ptr; void* o_ptr; + void* lse_ptr; // only used when both is_training and use_softmax are true ck_tile::index_t hdim_qk; ck_tile::index_t hdim_v; @@ -95,19 +106,14 @@ struct HstuAttentionGroupFwdParams ck_tile::index_t seq_stride_v; ck_tile::index_t seq_stride_bias; ck_tile::index_t seq_stride_o; + ck_tile::index_t seq_stride_lse; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_o; - - // batched mode only parameters - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_o; + ck_tile::index_t nhead_stride_lse; const void*
num_targets_ptr; @@ -120,8 +126,6 @@ struct HstuAttentionGroupFwdParams const void* group_contextual_seqlen_ptr; const void* group_min_full_attn_seqlen_ptr; - bool use_softmax; - float p_drop; uint64_t philox_seed; uint64_t philox_offset; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 6547986384..8b982c6f04 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -96,6 +96,8 @@ struct HstuAttentionFwdPipelineProblem static_assert(!kUseGroup || (kUseGroup && kIsJagged), "Group HSTU is only used with jagged mode!"); + static_assert(!kStoreLSE || (kStoreLSE && kUseSoftmax), + "Storing Lse is only necessary when softmax is used!"); using HstuAttentionTileSetting = remove_cvref_t;