From 1af27022efa6a6f57a504cdfdc9b5ef699c8063e Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sun, 27 Apr 2025 09:31:38 +0000 Subject: [PATCH] Add IsFullTileInsideMask() to avoid pixel-by-pixel checking when kUseCausal=true but kUseLocal=false --- .../hstu_attention_fwd_pipeline.hpp | 50 +++++++++++++++---- .../18_hstu_attention/hstu_block_masking.hpp | 33 ++++++++++++ 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 5949442163..1d6504475a 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -358,20 +358,48 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(HstuMask::IsMasking) { - constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tiles[i_k1].get_tile_distribution(), make_tuple(idx0, idx1)); + if constexpr(HstuMask::kUseLocal) + { + constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + sacc_tiles[i_k1].get_tile_distribution(), + make_tuple(idx0, idx1)); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); - sacc_tiles[i_k1](i_j_idx) *= - static_cast(mask.IsTokenPairInsideMask(row, col)); + 
sacc_tiles[i_k1](i_j_idx) *= static_cast( + mask.IsTokenPairInsideMask(row, col)); + }); }); - }); + } + else // kUseCausal=true, kUseLocal=false + { + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{})) + { + constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + sacc_tiles[i_k1].get_tile_distribution(), + make_tuple(idx0, idx1)); + + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + sacc_tiles[i_k1](i_j_idx) *= static_cast( + mask.IsTokenPairInsideMask(row, col)); + }); + }); + } + }; } else if constexpr(kPadSeqLenK) { diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 8d940adfe4..3eb2199502 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -115,6 +115,18 @@ struct HstuBlockMaskWithLocal return static_cast(result); } }; + + // if the whole tile is inside the masking area, no pixel-by-pixel checking is needed + template + CK_TILE_DEVICE constexpr bool + IsFullTileInsideMask(index_t i_tile_top, index_t i_tile_left, number) const + { + // when local masking is used, we assume all tiles need pixel-by-pixel checking + std::ignore = i_tile_top; + std::ignore = i_tile_left; + + return false; + } }; template @@ -185,6 +197,27 @@ struct HstuBlockMaskNoLocal return 1; }; + + // if the whole tile is inside the masking area, no pixel-by-pixel checking is needed + template + CK_TILE_DEVICE constexpr bool + IsFullTileInsideMask(index_t i_tile_top, index_t i_tile_left, number) const + { + if constexpr(kUseCausal) + { + index_t i_tile_right = i_tile_left + 
TileWidth; + + if(i_tile_right > i_tile_top) + return false; + + return true; + } + else + { + // need further check kPadSeqLenK in the masking context + return true; + }; + } }; template