diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 5d819aef4a..73c9b3fc30 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -397,77 +397,23 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, sacc_tile); } - if constexpr(HstuMask::IsMasking) + if(!mask.IsFullTileInsideMask( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{})) { - if constexpr(HstuMask::kUseLocal) - { - if(!mask.IsFullTileInsideMask(q_origin.at(number<0>{}), - seqlen_k_curr, - number{}, - number{})) - { - constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); - sacc_tile(i_j_idx) *= static_cast( - mask.IsTokenPairInsideMask(row, col)); - }); - }); - } - } - else // kUseCausal=true, kUseLocal=false - { - if(!mask.IsFullTileInsideMask(q_origin.at(number<0>{}), - seqlen_k_curr, - number{}, - number{})) - { - constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - sacc_tile(i_j_idx) *= static_cast( - mask.IsTokenPairInsideMask(row, col)); - }); - }); - } - }; - } - else if constexpr(kPadSeqLenK) - { - if(i_loop >= num_loops - 1) - { - constexpr auto s_spans = SaccBlockTileType::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - sacc_tile.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - sacc_tile(i_j_idx) *= static_cast( - mask.IsTokenPairInsideMask(row, col)); - }); + sacc_tile(i_j_idx) *= + static_cast(mask.IsTokenPairInsideMask(row, col)); }); - } + }); } pcomp_tile = cast_tile(sacc_tile);