From fff13b6c76482dbd9ae35d5835a511d16a15733f Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 15 Apr 2025 07:22:09 +0000 Subject: [PATCH] Update to partially reduce the register spilling --- .../hstu_attention_fwd_kernel.hpp | 10 +++---- .../hstu_attention_fwd_pipeline.hpp | 8 +++--- .../18_hstu_attention/hstu_block_masking.hpp | 26 +++++++++++++------ .../reference_hstu_attention.hpp | 4 +-- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 881ec297ab..56734e88cd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -543,13 +543,13 @@ struct HstuAttentionFwdKernel HstuMask mask = [&]() { if constexpr(kHasMask) - return HstuMask{kargs.window_size, + return HstuMask{kargs.seqlen, kargs.contextual_seqlen, - kargs.min_full_attn_seqlen, - kargs.seqlen, - num_target}; + num_target, + kargs.window_size, + kargs.min_full_attn_seqlen}; else - return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target}; + return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target}; }(); // for simplicity, batch stride we just modify the pointer diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index f97bde7126..341fbf8f8c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -248,12 +248,10 @@ struct HstuAttentionFwdPipelineQRKSVS auto s_acc = SaccBlockTileType{}; // reduction function for softmax - const auto f_silu = [](CompDataType x) { + const auto f_silu = [](CompDataType& x) { auto one = ck_tile::type_convert(1.0f); - auto sigmod_val = one / (one + exp(-x)); - - return sigmod_val * x; + return x = x / (one + exp(-x)); }; using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); @@ -405,7 +403,7 @@ struct HstuAttentionFwdPipelineQRKSVS auto s = cast_tile(s_acc); // S{j} - s = tile_elementwise_in(f_silu, s); + tile_elementwise_inout(f_silu, s); if constexpr(kHasDropout) { diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index d6307128e0..ef01876fcd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -12,22 +12,32 @@ struct HstuBlockMasking { static constexpr bool IsMasking = (kUseCausal || kUseLocal); - int max_attn_len; int contextual_seqlen; - int min_full_attn_seqlen; int max_uih_len; - CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_, + int max_attn_len; + int min_full_attn_seqlen; + + CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, - int min_full_attn_seqlen_, - int seqlen_, - int num_target) + int num_target, + int max_attn_len_, + int min_full_attn_seqlen_) { + max_uih_len = seqlen_; + contextual_seqlen = contextual_seqlen_; + max_attn_len = max_attn_len_; - contextual_seqlen = contextual_seqlen_; min_full_attn_seqlen = min_full_attn_seqlen_; - max_uih_len = seqlen_; + max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0; + max_uih_len -= num_target; + }; + + CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target) + { + max_uih_len = seqlen_; + contextual_seqlen = contextual_seqlen_; max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0; max_uih_len -= num_target; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index f87b7c2ff8..e8ea6fa1f0 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -106,9 +106,9 @@ struct reference_hstu_attention HstuMask mask = [&]() { if constexpr(kHasMask) return HstuMask{ - max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target}; + seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen}; else - return HstuMask{0, contextual_seqlen, 0, seqlen, num_target}; + return HstuMask{seqlen, contextual_seqlen, num_target}; }(); // for all rows in the batch