From 1304e807fb33870b1be23edcf780e0a4aabf8557 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 5 Jun 2026 10:33:32 +0000 Subject: [PATCH] Update and fix for leeked changes and make the scripts be able to test/benchmark kStoreLSE cases --- .../example_hstu_attention_fwd.cpp | 20 ++++----- ...ntion_batched_forward_splitkv_dispatch.hpp | 2 +- ...u_attention_fwd_splitkv_combine_kernel.hpp | 42 +++++++++++++++++++ ...tu_attention_with_softmax_fwd_pipeline.hpp | 12 +++--- ...h_softmax_fwd_splitkv_combine_pipeline.hpp | 19 ++++++++- .../reference_hstu_attention_fwd.hpp | 9 +++- .../bench_jagged_causal_mattn0_full0.sh | 6 ++- .../bench_jagged_causal_mattn256_full256.sh | 6 ++- .../test_cross_attention_with_sparsity.sh | 6 ++- .../test_group_hstu_softmax_attention.sh | 4 +- .../scripts/test_hstu_cross_attention.sh | 6 ++- .../scripts/test_hstu_softmax_attention.sh | 4 +- ...st_hstu_softmax_attention_hdim96_hdim64.sh | 4 +- .../test_jagged_causal_mattn0_full0.sh | 6 ++- .../test_jagged_causal_mattn256_full0.sh | 5 ++- .../test_jagged_causal_mattn256_full256.sh | 6 ++- 16 files changed, 120 insertions(+), 37 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp index b346f423b9..901993f805 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention_fwd.cpp @@ -415,6 +415,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged) int batches_for_alloc = is_jagged ? 1 : num_batch; + bool store_lse = (is_training & use_softmax); + ck_tile::HostTensor q_host( std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk}); ck_tile::HostTensor k_host( @@ -424,9 +426,8 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged) ck_tile::HostTensor o_host_ref( std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); ck_tile::HostTensor lse_host_ref( - (is_training && use_softmax) - ? std::array{batches_for_alloc, phy_seqlen_q, num_head} - : std::array{1, 1, 1}); + store_lse ? std::array{batches_for_alloc, phy_seqlen_q, num_head} + : std::array{1, 1, 1}); ck_tile::HostTensor mask_host( save_mask @@ -593,7 +594,6 @@ bool run_no_group_hstu(const ck_tile::ArgParser& arg_parser, bool is_jagged) using GemmAccDataType = typename HstuAttentionFwdTypeConfig::GemmAccDataType; BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] { - bool store_lse = (is_training && use_softmax); ck_tile::reference_no_group_hstu_attention_fwd lse_host( std::array{batches_for_alloc, phy_seqlen_q, num_head}); @@ -913,6 +913,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) int batches_for_alloc = 1; + bool store_lse = (is_training & use_softmax); + ck_tile::HostTensor q_host( std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_qk}); ck_tile::HostTensor k_host( @@ -922,9 +924,8 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) ck_tile::HostTensor o_host_ref( std::array{batches_for_alloc, phy_seqlen_q, num_head, hdim_v}); ck_tile::HostTensor lse_host_ref( - (is_training && use_softmax) - ? std::array{batches_for_alloc, phy_seqlen_q, num_head} - : std::array{1, 1, 1}); + store_lse ? std::array{batches_for_alloc, phy_seqlen_q, num_head} + : std::array{1, 1, 1}); ck_tile::HostTensor mask_host(save_mask ? std::array{num_batch, @@ -1054,7 +1055,6 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) using GemmAccDataType = typename HstuAttentionFwdTypeConfig::GemmAccDataType; BOOL_SWITCH(use_causal, kUseCausal, [&] { - bool store_lse = (is_training && use_softmax); ck_tile::reference_group_hstu_attention_fwd< InOutDataType, GemmAccDataType, @@ -1103,7 +1103,7 @@ bool run_group_hstu(const ck_tile::ArgParser& arg_parser, int num_group) res = ck_tile::check_err( o_host, o_host_ref, std::string("hstu_attention output error"), rtol, atol); - if(is_training && use_softmax) + if(store_lse) { ck_tile::HostTensor lse_host( std::array{batches_for_alloc, phy_seqlen_q, num_head}); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp index 2a0e55d5ad..120517d2e0 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_splitkv_dispatch.hpp @@ -79,7 +79,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch ODataType, false /* kIsJagged */, kUseSoftmax, - false, // kStoreLSE + kStoreLSE, HstuAttentionCombineTileSetting, kMaxSplits>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp index fdcb1ddbc9..e74c84f75e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_splitkv_combine_kernel.hpp @@ -281,6 +281,7 @@ struct HstuAttentionFwdSplitKVCombineKernel long_index_t batch_offset_o_acc = 0; long_index_t batch_offset_o = 0; long_index_t batch_offset_lse_acc = 0; + long_index_t batch_offset_lse = 0; if constexpr(kIsJagged) { @@ -299,6 +300,11 @@ struct HstuAttentionFwdSplitKVCombineKernel { batch_offset_lse_acc = query_start * kargs.num_head * kargs.num_splits; } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start * kargs.seq_stride_lse; + } + kargs.seqlen_q = kargs.seq_q_offsets_ptr[i_batch + 1] - kargs.seq_q_offsets_ptr[i_batch]; } @@ -317,6 +323,10 @@ struct HstuAttentionFwdSplitKVCombineKernel batch_offset_lse_acc = static_cast(i_batch) * kargs.seqlen_q * kargs.num_head * kargs.num_splits; } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } } index_t i_m0; @@ -394,8 +404,40 @@ struct HstuAttentionFwdSplitKVCombineKernel number{}), {i_m0, 0}); + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = + make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = + make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(kargs.seq_stride_lse), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + return HstuAttentionPipeline{}(lse_acc_dram_window, o_acc_dram_window, + lse_dram_window, kargs.hdim_v, kargs.num_splits, smem_ptr); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 62433bd678..5df4d905ce 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -119,7 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEorLSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindow, typename QElementFunction, typename BiasElementFunction, typename LSEaccElementFunction, @@ -134,7 +134,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile + LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile const LSEaccElementFunction& lse_or_lse_acc_element_func, const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, @@ -204,7 +204,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS clear_tile(o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - if constexpr(!is_null_tile_window_v) + if constexpr(!is_null_tile_window_v) { auto lse_or_lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); @@ -600,7 +600,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // if pipeline is called from splitkv_kernel, the window shall not be null; // if pipeline is called from non-splitkv kernel, the window is null if kStoreLSE is false - if constexpr(!is_null_tile_window_v) + if constexpr(!is_null_tile_window_v) { // store lse_or_lse_acc auto lse_or_lse_acc = @@ -641,14 +641,14 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEorLSEaccDramBlockWindowTmp, + typename LSEorLSEaccDramBlockWindow, typename HstuMask> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEorLSEaccDramBlockWindowTmp& lse_or_lse_acc_dram_block_window, // M0 tile + LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile index_t seqlen_k_start, index_t seqlen_k_end, HstuMask mask, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp index a22f91b084..fb5d14ef2f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_splitkv_combine_pipeline.hpp @@ -63,13 +63,17 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline template + typename LSEaccElementFunction, + typename LSEElementFunction> CK_TILE_DEVICE auto operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // kM tile const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile + LSEDramBlockWindow& lse_dram_block_window, // kM tile const OAccElementFunction& o_acc_element_func, const LSEaccElementFunction& lse_acc_element_func, + const LSEElementFunction& lse_element_func, index_t o_acc_split_stride, index_t num_splits, void* smem_ptr) const @@ -166,6 +170,12 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline }); } + // in case kStoreLSE is false, LSEDramBlockWindow is null + if constexpr(!is_null_tile_window_v) + { + store_tile(lse_dram_block_window, tile_elementwise_in(lse_element_func, lse_logsum)); + } + // calculate scale value (used for adjusting the o_acc) for all splits for all rows in // the tile lse_acc_type& lse_scale = lse_acc; @@ -235,16 +245,21 @@ struct HstuAttentionWithSoftmaxFwdSplitKVCombinePipeline return o_acc; } - template + template CK_TILE_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window_tmp, const OAccDramBlockWindowTmp& o_acc_dram_block_window_tmp, // kM*kOHeaddim tile + LSEDramBlockWindow& lse_dram_block_window, // kM tile ck_tile::index_t o_acc_split_stride, index_t num_splits, void* smem_ptr) const { return operator()(lse_acc_dram_block_window_tmp, o_acc_dram_block_window_tmp, + lse_dram_block_window, + identity{}, identity{}, identity{}, o_acc_split_stride, diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp index 2eb885ecf0..cba68e43c6 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention_fwd.hpp @@ -294,7 +294,11 @@ struct reference_no_group_hstu_attention_fwd if(store_lse) { - lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m; + if constexpr(kIsJagged) + lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) = + std::log(l) + m; + else + lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m; } }; @@ -583,7 +587,8 @@ struct reference_group_hstu_attention_fwd if(store_lse) { - lse_batch_seq_nhead(i_batch, sq, i_head) = std::log(l) + m; + lse_batch_seq_nhead(0, seq_q_offsets[i_batch] + sq, i_head) = + std::log(l) + m; } }; diff --git a/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn0_full0.sh b/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn0_full0.sh index 9cd371aaa6..359f1d384d 100644 --- a/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn0_full0.sh +++ b/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn0_full0.sh @@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention" + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" fi dtype="bf16" diff --git a/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn256_full256.sh b/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn256_full256.sh index e992d61849..17b7e0c9ea 100644 --- a/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn256_full256.sh +++ b/example/ck_tile/18_hstu_attention/scripts/bench_jagged_causal_mattn256_full256.sh @@ -9,10 +9,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention" + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" fi dtype="bf16" diff --git a/example/ck_tile/18_hstu_attention/scripts/test_cross_attention_with_sparsity.sh b/example/ck_tile/18_hstu_attention/scripts/test_cross_attention_with_sparsity.sh index ddbd325896..ddde9afd8f 100755 --- a/example/ck_tile/18_hstu_attention/scripts/test_cross_attention_with_sparsity.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_cross_attention_with_sparsity.sh @@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention -v=1" + EXE="$BUILD/bin/tile_example_hstu_attention -v=1 -training=$Training" fi set -x diff --git a/example/ck_tile/18_hstu_attention/scripts/test_group_hstu_softmax_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_group_hstu_softmax_attention.sh index 84546d178e..2b106073a3 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_group_hstu_softmax_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_group_hstu_softmax_attention.sh @@ -1,7 +1,9 @@ #!/bin/bash BUILD=build -EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" +Training=${TEST_HSTU_FWD_TRAINING:-0} + +EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" ndist=0 diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh index 647b72c074..9f17cc2d36 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_cross_attention.sh @@ -7,10 +7,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention" + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" fi for T in "fp16" "bf16"; do diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh index bee42b45ea..72e58240b1 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention.sh @@ -1,7 +1,9 @@ #!/bin/bash BUILD=build -EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" +Training=${TEST_HSTU_FWD_TRAINING:-0} + +EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" attn_scale=0 if [ $# -ge 1 ]; then diff --git a/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention_hdim96_hdim64.sh b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention_hdim96_hdim64.sh index 7aacfb2e12..77b66cf229 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention_hdim96_hdim64.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_hstu_softmax_attention_hdim96_hdim64.sh @@ -2,7 +2,9 @@ ## This script can be used the verifying the using of WarpGemm 32x32x16 which is used by hdim64 + softmax BUILD=build -EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" +Training=${TEST_HSTU_FWD_TRAINING:-0} + +EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" attn_scale=1.0 ndist=1 diff --git a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0.sh b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0.sh index 85f636a828..eeab30bd1f 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn0_full0.sh @@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention" + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" fi dtype="bf16" diff --git a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full0.sh b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full0.sh index 339f7ca2cf..cb06e4f4f8 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full0.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full0.sh @@ -8,7 +8,10 @@ fi set +x BUILD=build -EXE=$BUILD/bin/tile_example_hstu_attention + +Training=${TEST_HSTU_FWD_TRAINING:-0} + +EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" dtype="bf16" diff --git a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256.sh b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256.sh index 2667ef8c96..5eb49b21d1 100644 --- a/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256.sh +++ b/example/ck_tile/18_hstu_attention/scripts/test_jagged_causal_mattn256_full256.sh @@ -10,10 +10,12 @@ if [ $# -ge 1 ]; then USE_SOFTMAX=$1 fi +Training=${TEST_HSTU_FWD_TRAINING:-0} + if [ $USE_SOFTMAX -eq 1 ]; then - EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1" + EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1 -training=$Training" else - EXE="$BUILD/bin/tile_example_hstu_attention" + EXE="$BUILD/bin/tile_example_hstu_attention -training=$Training" fi dtype="bf16"