mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Use buffer_view to create lse_acc_dram_naive so that out_of_boundary loading value can be specified (be -inf)
This commit is contained in:
@@ -317,13 +317,22 @@ struct HstuAttentionFwdSplitKVCombineKernel
|
||||
static_cast<long_index_t>(i_nhead) * kargs.num_splits + batch_offset_lse_acc;
|
||||
|
||||
// LSEacc DRAM and LSEacc DRAM window
|
||||
auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits;
|
||||
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
auto seq_stride_lse_acc = kargs.num_head * kargs.num_splits;
|
||||
|
||||
auto lse_acc_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(kargs.seqlen_q, kargs.num_splits),
|
||||
make_tuple(seq_stride_lse_acc, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentLSEacc>{},
|
||||
number<1>{});
|
||||
|
||||
auto lse_acc_buf_view = make_buffer_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.num_splits),
|
||||
make_tuple(seq_stride_lse_acc, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentLSEacc>{},
|
||||
number<1>{});
|
||||
lse_acc_desc.get_element_space_size(),
|
||||
-numeric<LSEDataType>::infinity());
|
||||
|
||||
auto lse_acc_dram_naive =
|
||||
tensor_view<decltype(lse_acc_buf_view), decltype(lse_acc_desc)>{
|
||||
lse_acc_buf_view, lse_acc_desc};
|
||||
|
||||
const auto lse_acc_dram =
|
||||
pad_tensor_view(lse_acc_dram_naive,
|
||||
|
||||
@@ -117,20 +117,6 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
|
||||
// provide partition_index for LDS tile window so that warp_id is in vgpr
|
||||
array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
|
||||
|
||||
// ToDo: use buffer_view interface to enable the tile loading to set -inf for oob elements
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(lse_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_acc.get_tile_distribution(), i_j_idx, partition_index);
|
||||
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col >= num_splits)
|
||||
lse_acc(i_j_idx) = -numeric<LSEDataType>::infinity();
|
||||
});
|
||||
});
|
||||
|
||||
// calculate max of lse_acc[] across all splits for all rows in the tile, lse_max is
|
||||
// only used for stablizing the exp()
|
||||
auto lse_max = block_tile_reduce<LSEDataType>(
|
||||
|
||||
Reference in New Issue
Block a user