Add support for preparing lse_dram_window in hstu fwd kernel

This commit is contained in:
Qianfeng Zhang
2026-06-04 09:58:13 +00:00
parent 75a3b5aab0
commit cc184fc202
5 changed files with 92 additions and 51 deletions

View File

@@ -41,6 +41,8 @@ struct HstuAttentionFwdKernel
using BiasDataType =
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::BiasDataType>;
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::ODataType>;
using CompDataType =
ck_tile::remove_cvref_t<typename HstuAttentionPipeline::Problem::CompDataType>;
static constexpr bool kIsCrossAttention = HstuAttentionPipeline::Problem::kIsCrossAttention;
static constexpr bool kUseGroup = HstuAttentionPipeline::Problem::kUseGroup;
@@ -701,6 +703,7 @@ struct HstuAttentionFwdKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_lse = 0;
if constexpr(kIsJagged)
{
@@ -717,6 +720,10 @@ struct HstuAttentionFwdKernel
batch_offset_bias = query_start * kargs.seq_stride_bias;
}
batch_offset_o = query_start * kargs.seq_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start * kargs.seq_stride_lse;
}
kargs.seqlen_q =
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
@@ -747,6 +754,10 @@ struct HstuAttentionFwdKernel
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
@@ -916,6 +927,35 @@ struct HstuAttentionFwdKernel
}
}();
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths =
make_tuple(number<HstuAttentionPipeline::kM0>{});
if constexpr(kStoreLSE)
{
CompDataType* lse_ptr =
reinterpret_cast<CompDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(seqlen_q_in_ctrl),
make_tuple(kargs.seq_stride_lse),
number<1>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
@@ -985,13 +1025,11 @@ struct HstuAttentionFwdKernel
}
else
{
auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple());
return HstuAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
null_tile_window,
lse_dram_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1040,12 +1078,11 @@ struct HstuAttentionFwdKernel
}
else
{
auto null_tile_window = ck_tile::make_null_tile_window(ck_tile::make_tuple());
return HstuAttentionPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
null_tile_window,
lse_dram_window,
seqlen_k_start,
seqlen_k_end,
mask,

View File

@@ -43,14 +43,14 @@ struct HstuAttentionNoGroupFwdParams
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t seq_stride_o;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q]
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead]
// batched mode only parameters
ck_tile::index_t batch_stride_q;
@@ -106,14 +106,14 @@ struct HstuAttentionGroupFwdParams
ck_tile::index_t seq_stride_v;
ck_tile::index_t seq_stride_bias;
ck_tile::index_t seq_stride_o;
ck_tile::index_t seq_stride_lse;
ck_tile::index_t seq_stride_lse; // not needed if lse layout is [nhead, seqlen_q]
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse; // not needed if lse layout is [seqlen_q, nhead]
const void* num_targets_ptr;

View File

@@ -81,7 +81,6 @@ struct HstuAttentionFwdPipelineProblem
using CompDataType = remove_cvref_t<CompDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
// to be compatible with ck_tile existing policy codes
using OaccDataType = GemmAccDataType;
using PDataType = QKVDataType;

View File

@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
@@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
const LSEaccElementFunction& lse_acc_element_func,
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
@@ -204,15 +204,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
{
auto lse_acc =
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<CompDataType>::infinity());
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
store_tile(lse_acc_dram_block_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
}
return o_acc;
@@ -598,19 +598,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
};
} while(seqlen_k_curr < seqlen_k_end);
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
// if pipeline is called from splitkv_kernel, the window shall not be null;
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
{
// store lse acc
auto lse_acc = make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
// store lse_or_lse_acc
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans();
sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
});
store_tile(lse_acc_dram_block_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
@@ -638,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
@@ -660,7 +663,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
lse_or_lse_acc_dram_block_window,
identity{},
identity{},
identity{},

View File

@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
@@ -134,8 +134,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0 tile
const LSEaccElementFunction& lse_acc_element_func,
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
@@ -205,15 +205,15 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
{
auto lse_acc =
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<CompDataType>::infinity());
set_tile(lse_or_lse_acc, -numeric<CompDataType>::infinity());
store_tile(lse_acc_dram_block_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
}
return o_acc;
@@ -604,19 +604,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
};
} while(seqlen_k_curr < seqlen_k_end);
if constexpr(!is_null_tile_window_v<LSEaccDramBlockWindowTmp>)
// if pipeline is called from splitkv_kernel, the window shall not be null;
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
{
// store lse acc
auto lse_acc = make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
// store lse or lse_acc
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
constexpr auto lse_or_lse_acc_spans = decltype(lse_or_lse_acc)::get_distributed_spans();
sweep_tile_span(lse_or_lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
lse_or_lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
});
store_tile(lse_acc_dram_block_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
store_tile(lse_or_lse_acc_dram_block_window,
tile_elementwise_in(lse_or_lse_acc_element_func, lse_or_lse_acc));
}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
@@ -644,15 +647,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp&
lse_acc_dram_block_window_tmp, // M0 tile //
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,
@@ -667,7 +669,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
v_dram_block_window_tmp,
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
lse_or_lse_acc_dram_block_window,
identity{},
identity{},
identity{},