From cc184fc2020939cabcb534d65cc2af0792941106 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 4 Jun 2026 09:58:13 +0000 Subject: [PATCH] Add support for preparing lse_dram_window in hstu fwd kernel --- .../hstu_attention_fwd_kernel.hpp | 47 +++++++++++++++++-- .../hstu_attention_params.hpp | 8 ++-- .../hstu_attention_pipeline_problem.hpp | 1 - ...tu_attention_with_softmax_fwd_pipeline.hpp | 43 +++++++++-------- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 44 ++++++++--------- 5 files changed, 92 insertions(+), 51 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 2b9f7a3d2f..d0ecab36bd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -41,6 +41,8 @@ struct HstuAttentionFwdKernel using BiasDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; + using CompDataType = + ck_tile::remove_cvref_t; static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention; static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup; @@ -701,6 +703,7 @@ struct HstuAttentionFwdKernel long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; long_index_t batch_offset_o = 0; + long_index_t batch_offset_lse = 0; if constexpr(kIsJagged) { @@ -717,6 +720,10 @@ struct HstuAttentionFwdKernel batch_offset_bias = query_start * kargs.seq_stride_bias; } batch_offset_o = query_start * kargs.seq_stride_o; + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start * kargs.seq_stride_lse; + } kargs.seqlen_q = kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch]; @@ -747,6 +754,10 @@ struct HstuAttentionFwdKernel batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } } int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; @@ -916,6 +927,35 @@ struct HstuAttentionFwdKernel } }(); + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = + make_tuple(number{}); + if constexpr(kStoreLSE) + { + CompDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(seqlen_q_in_ctrl), + make_tuple(kargs.seq_stride_lse), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { @@ -985,13 +1025,11 @@ struct HstuAttentionFwdKernel } else { - auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple()); - return HstuAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, bias_dram_window, - null_tile_window, + lse_dram_window, seqlen_k_start, seqlen_k_end, mask, @@ -1040,12 +1078,11 @@ struct HstuAttentionFwdKernel } else { - auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple()); return HstuAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, bias_dram_window, - null_tile_window, + lse_dram_window, seqlen_k_start, seqlen_k_end, mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp index 91dcacd527..2c20c4b4ab 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -43,14 +43,14 @@ struct HstuAttentionNoGroupFwdParams ck_tile::index_t seq_stride_v; ck_tile::index_t seq_stride_bias; ck_tile::index_t seq_stride_o; - ck_tile::index_t seq_stride_lse; + ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q] ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead] // batched mode only parameters ck_tile::index_t batch_stride_q; @@ -106,14 +106,14 @@ struct HstuAttentionGroupFwdParams ck_tile::index_t seq_stride_v; ck_tile::index_t seq_stride_bias; ck_tile::index_t seq_stride_o; - ck_tile::index_t seq_stride_lse; + ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q] ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead] const void* num_targets_ptr; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp index 8b982c6f04..8252f77fbe 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -81,7 +81,6 @@ struct HstuAttentionFwdPipelineProblem using CompDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; - // to be compatible with ck_tile existing policy codes using OaccDataType = GemmAccDataType; using PDataType = QKVDataType; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index c282f4946e..62433bd678 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindowTmp, typename QElementFunction, typename BiasElementFunction, typename LSEaccElementFunction, @@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile - const LSEaccElementFunction& lse_acc_element_func, + LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile + const LSEaccElementFunction& lse_or_lse_acc_element_func, const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, @@ -204,15 +204,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS clear_tile(o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - if constexpr(!is_null_tile_window_v) + if constexpr(!is_null_tile_window_v) { - auto lse_acc = + auto lse_or_lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + set_tile(lse_or_lse_acc, -numeric::infinity()); - store_tile(lse_acc_dram_block_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, + tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); } return o_acc; @@ -598,19 +598,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS }; } while(seqlen_k_curr < seqlen_k_end); - if constexpr(!is_null_tile_window_v) + // if pipeline is called from splitkv_kernel, the window shall not be null; + // if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false + if constexpr(!is_null_tile_window_v) { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + // store lse_or_lse_acc + auto lse_or_lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); + constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); }); - store_tile(lse_acc_dram_block_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, + tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); } constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); @@ -638,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile + LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask mask, @@ -660,7 +663,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS v_dram_block_window_tmp, bias_dram_block_window_tmp, identity{}, - lse_acc_dram_block_window_tmp, + lse_or_lse_acc_dram_block_window, identity{}, identity{}, identity{}, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 52201a74df..13944f04fb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindowTmp, typename QElementFunction, typename BiasElementFunction, typename LSEaccElementFunction, @@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile - const LSEaccElementFunction& lse_acc_element_func, + LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile + const LSEaccElementFunction& lse_or_lse_acc_element_func, const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, @@ -205,15 +205,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad clear_tile(o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - if constexpr(!is_null_tile_window_v) + if constexpr(!is_null_tile_window_v) { - auto lse_acc = + auto lse_or_lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + set_tile(lse_or_lse_acc, -numeric::infinity()); - store_tile(lse_acc_dram_block_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, + tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); } return o_acc; @@ -604,19 +604,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad }; } while(seqlen_k_curr < seqlen_k_end); - if constexpr(!is_null_tile_window_v) + // if pipeline is called from splitkv_kernel, the window shall not be null; + // if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false + if constexpr(!is_null_tile_window_v) { - // store lse acc - auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + // store lse or lse_acc + auto lse_or_lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); - constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); - sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); + constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); }); - store_tile(lse_acc_dram_block_window_tmp, - tile_elementwise_in(lse_acc_element_func, lse_acc)); + store_tile(lse_or_lse_acc_dram_block_window, + tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc)); } constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); @@ -644,15 +647,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindowTmp, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEaccDramBlockWindowTmp& - lse_acc_dram_block_window_tmp, // M0 tile // + LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask mask, @@ -667,7 +669,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad v_dram_block_window_tmp, bias_dram_block_window_tmp, identity{}, - lse_acc_dram_block_window_tmp, + lse_or_lse_acc_dram_block_window, identity{}, identity{}, identity{},