Update and fix leaked changes, and make the scripts able to test/benchmark kStoreLSE cases

This commit is contained in:
Qianfeng Zhang
2026-06-05 10:33:32 +00:00
parent 798fd3cd8b
commit 1304e807fb
16 changed files with 120 additions and 37 deletions

View File

@@ -415,6 +415,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
int batches_for_alloc = is_jagged ? 1 : num_batch;
bool store_lse = (is_training & use_softmax);
ck_tile::HostTensor<InOutDataType> q_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> k_host(
@@ -424,9 +426,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<CompDataType> lse_host_ref(
(is_training && use_softmax)
? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
store_lse ? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<int8_t> mask_host(
save_mask
@@ -593,7 +594,6 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
bool store_lse = (is_training && use_softmax);
ck_tile::reference_no_group_hstu_attention_fwd<InOutDataType,
GemmAccDataType,
CompDataType,
@@ -640,7 +640,7 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged)
res = ck_tile::check_err(
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
if(is_training && use_softmax)
if(store_lse)
{
ck_tile::HostTensor<CompDataType> lse_host(
std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head});
@@ -913,6 +913,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
int batches_for_alloc = 1;
bool store_lse = (is_training & use_softmax);
ck_tile::HostTensor<InOutDataType> q_host(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk});
ck_tile::HostTensor<InOutDataType> k_host(
@@ -922,9 +924,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
ck_tile::HostTensor<InOutDataType> o_host_ref(
std::array<ck_tile::index_t, 4>{batches_for_alloc, phy_seqlen_q, num_head, hdim_v});
ck_tile::HostTensor<CompDataType> lse_host_ref(
(is_training && use_softmax)
? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
store_lse ? std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head}
: std::array<ck_tile::index_t, 3>{1, 1, 1});
ck_tile::HostTensor<int8_t> mask_host(save_mask
? std::array<ck_tile::index_t, 4>{num_batch,
@@ -1054,7 +1055,6 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
BOOL_SWITCH(use_causal, kUseCausal, [&] {
bool store_lse = (is_training && use_softmax);
ck_tile::reference_group_hstu_attention_fwd<
InOutDataType,
GemmAccDataType,
@@ -1103,7 +1103,7 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group)
res = ck_tile::check_err(
o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol);
if(is_training && use_softmax)
if(store_lse)
{
ck_tile::HostTensor<CompDataType> lse_host(
std::array<ck_tile::index_t, 3>{batches_for_alloc, phy_seqlen_q, num_head});

View File

@@ -79,7 +79,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
ODataType,
false /* kIsJagged */,
kUseSoftmax,
false, // kStoreLSE
kStoreLSE,
HstuAttentionCombineTileSetting,
kMaxSplits>;

View File

@@ -281,6 +281,7 @@ struct HstuAttentionFwdSplitKVCombineKernel
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_lse = 0;
if constexpr(kIsJagged)
{
@@ -299,6 +300,11 @@ struct HstuAttentionFwdSplitKVCombineKernel
{
batch_offset_lse_acc = query_start * kargs.num_head * kargs.num_splits;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start * kargs.seq_stride_lse;
}
kargs.seqlen_q =
kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch];
}
@@ -317,6 +323,10 @@ struct HstuAttentionFwdSplitKVCombineKernel
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.seqlen_q *
kargs.num_head * kargs.num_splits;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
index_t i_m0;
@@ -394,8 +404,40 @@ struct HstuAttentionFwdSplitKVCombineKernel
number<HstuAttentionPipeline::kMaxSplits>{}),
{i_m0, 0});
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths =
make_tuple(number<HstuAttentionPipeline::kM>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(kargs.seq_stride_lse),
number<1>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<false>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
return HstuAttentionPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
kargs.hdim_v,
kargs.num_splits,
smem_ptr);

View File

@@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename QElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
@@ -134,7 +134,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
const LSEaccElementFunction& lse_or_lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
@@ -204,7 +204,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
clear_tile(o_acc);
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
auto lse_or_lse_acc =
make_static_distributed_tensor<CompDataType>(m.get_tile_distribution());
@@ -600,7 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
// if pipeline is called from splitkv_kernel, the window shall not be null;
// if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindowTmp>)
if constexpr(!is_null_tile_window_v<LSEorLSEaccDramBlockWindow>)
{
// store lse_or_lse_acc
auto lse_or_lse_acc =
@@ -641,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask mask,

View File

@@ -63,13 +63,17 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
template <typename LSEaccDramBlockWindowTmp,
typename OAccDramBlockWindowTmp,
typename LSEDramBlockWindow,
typename OAccElementFunction,
typename LSEaccElementFunction>
typename LSEaccElementFunction,
typename LSEElementFunction>
CK_TILE_DEVICE auto
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
LSEDramBlockWindow& lse_dram_block_window, // kM tile
const OAccElementFunction& o_acc_element_func,
const LSEaccElementFunction& lse_acc_element_func,
const LSEElementFunction& lse_element_func,
index_t o_acc_split_stride,
index_t num_splits,
void* smem_ptr) const
@@ -166,6 +170,12 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
});
}
// in case kStoreLSE is false, LSEDramBlockWindow is null
if constexpr(!is_null_tile_window_v<LSEDramBlockWindow>)
{
store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum));
}
// calculate scale value (used for adjusting the o_acc) for all splits for all rows in
// the tile
lse_acc_type& lse_scale = lse_acc;
@@ -235,16 +245,21 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline
return o_acc;
}
template <typename LSEaccDramBlockWindow, typename OAccDramBlockWindowTmp>
template <typename LSEaccDramBlockWindow,
typename OAccDramBlockWindowTmp,
typename LSEDramBlockWindow>
CK_TILE_DEVICE auto
operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp,
const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile
LSEDramBlockWindow& lse_dram_block_window, // kM tile
ck_tile::index_t o_acc_split_stride,
index_t num_splits,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window_tmp,
o_acc_dram_block_window_tmp,
lse_dram_block_window,
identity{},
identity{},
identity{},
o_acc_split_stride,

View File

@@ -294,7 +294,11 @@ struct reference_no_group_hstu_attention_fwd
if(store_lse)
{
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
if constexpr(kIsJagged)
lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) =
std::log(l) + m;
else
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
}
};
@@ -583,7 +587,8 @@ struct reference_group_hstu_attention_fwd
if(store_lse)
{
lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m;
lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) =
std::log(l) + m;
}
};

View File

@@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention"
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
fi
dtype="bf16"

View File

@@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention"
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
fi
dtype="bf16"

View File

@@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention -v=1"
EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -training=$Training"
fi
set -x

View File

@@ -1,7 +1,9 @@
#!/bin/bash
BUILD=build
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
Training=${TEST_HSTU_FWD_TRAINING:-0}
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
ndist=0

View File

@@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention"
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
fi
for T in "fp16" "bf16"; do

View File

@@ -1,7 +1,9 @@
#!/bin/bash
BUILD=build
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
Training=${TEST_HSTU_FWD_TRAINING:-0}
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
attn_scale=0
if [ $# -ge 1 ]; then

View File

@@ -2,7 +2,9 @@
## This script can be used for verifying the use of WarpGemm 32x32x16, which is used by hdim64 + softmax
BUILD=build
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
Training=${TEST_HSTU_FWD_TRAINING:-0}
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
attn_scale=1.0
ndist=1

View File

@@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention"
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
fi
dtype="bf16"

View File

@@ -8,7 +8,10 @@ fi
set +x
BUILD=build
EXE=$BUILD/bin/tile_example_hstu_attention
Training=${TEST_HSTU_FWD_TRAINING:-0}
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
dtype="bf16"

View File

@@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then
USE_SOFTMAX=$1
fi
Training=${TEST_HSTU_FWD_TRAINING:-0}
if [ $USE_SOFTMAX -eq 1 ]; then
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training"
else
EXE="$BUILD/bin/tile_example_hstu_attention"
EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training"
fi
dtype="bf16"