From 30b5d7bd0193faed28d34ffffab0330281f66693 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 23 May 2026 04:37:52 +0000 Subject: [PATCH] Use buffer_view to create lse_acc_dram_naive so that the out-of-bounds load value can be specified (-inf) --- ...u_attention_fwd_splitkv_combine_kernel.hpp | 21 +++++++++++++------ ...h_softmax_fwd_splitkv_combine_pipeline.hpp | 14 ------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp index c2b8492012..332d51268b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp @@ -317,13 +317,22 @@ struct HstuAttentionFwdSplitKVCombineKernel static_cast(i_nhead) * kargs.num_splits + batch_offset_lse_acc; // LSEacc DRAM and LSEacc DRAM window - auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits; - const auto lse_acc_dram_naive = make_naive_tensor_view( + auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits; + + auto lse_acc_desc = + make_naive_tensor_descriptor(make_tuple(kargs.seqlen_q, kargs.num_splits), + make_tuple(seq_stride_lse_acc, 1), + number{}, + number<1>{}); + + auto lse_acc_buf_view = make_buffer_view( lse_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.num_splits), - make_tuple(seq_stride_lse_acc, 1), - number{}, - number<1>{}); + lse_acc_desc.get_element_space_size(), + -numeric::infinity()); + + auto lse_acc_dram_naive = + tensor_view{ + lse_acc_buf_view, lse_acc_desc}; const auto lse_acc_dram = pad_tensor_view(lse_acc_dram_naive, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp index 546e8926a6..a22f91b084 100644 --- 
a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp @@ -117,20 +117,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline // provide partition_index for LDS tile window so that warp_id is in vgpr array partition_index{get_warp_id(), get_lane_id()}; - // ToDo: use buffer_view interface to enable the tile loading to set -inf for oob elements - sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(lse_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - const auto x_indices = get_x_indices_from_distributed_indices( - lse_acc.get_tile_distribution(), i_j_idx, partition_index); - - const auto col = x_indices.at(number<1>{}); - if(col >= num_splits) - lse_acc(i_j_idx) = -numeric::infinity(); - }); - }); - // calculate max of lse_acc[] across all splits for all rows in the tile, lse_max is // only used for stablizing the exp() auto lse_max = block_tile_reduce(