Use buffer_view to create lse_acc_dram_naive so that the value returned by out-of-bounds loads can be specified (set to -inf)

This commit is contained in:
Qianfeng Zhang
2026-05-23 04:37:52 +00:00
parent 9a7cc5b4a3
commit 30b5d7bd01
2 changed files with 15 additions and 20 deletions

View File

@@ -317,13 +317,22 @@ struct HstuAttentionFwdSplitKVCombineKernel
static_cast<long_index_t>(i_nhead) * kargs.num_splits + batch_offset_lse_acc;
// LSEacc DRAM and LSEacc DRAM window
auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits;
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits;
auto lse_acc_desc =
make_naive_tensor_descriptor(make_tuple(kargs.seqlen_q, kargs.num_splits),
make_tuple(seq_stride_lse_acc, 1),
number<HstuAttentionPipeline::kAlignmentLSEacc>{},
number<1>{});
auto lse_acc_buf_view = make_buffer_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.num_splits),
make_tuple(seq_stride_lse_acc, 1),
number<HstuAttentionPipeline::kAlignmentLSEacc>{},
number<1>{});
lse_acc_desc.get_element_space_size(),
-numeric<LSEDataType>::infinity());
auto lse_acc_dram_naive =
tensor_view<decltype(lse_acc_buf_view), decltype(lse_acc_desc)>{
lse_acc_buf_view, lse_acc_desc};
const auto lse_acc_dram =
pad_tensor_view(lse_acc_dram_naive,

View File

@@ -117,20 +117,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
// provide partition_index for LDS tile window so that warp_id is in vgpr
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
// ToDo: use buffer_view interface to enable the tile loading to set -inf for oob elements
sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(lse_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_acc.get_tile_distribution(), i_j_idx, partition_index);
const auto col = x_indices.at(number<1>{});
if(col >= num_splits)
lse_acc(i_j_idx) = -numeric<LSEDataType>::infinity();
});
});
// calculate max of lse_acc[] across all splits for all rows in the tile, lse_max is
// only used for stabilizing the exp()
auto lse_max = block_tile_reduce<LSEDataType>(