diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index cdc8fb6c3a..293b191eaa 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -362,15 +362,16 @@ struct HstuAttentionFwdPipelineQRKSVS } else if constexpr(kPadSeqLenK) { - set_tile_if( - sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { - if(q_origin.at(number<0>{}) + kM0 <= mask.max_uih_len && - i_loop < num_loops - 1) - return false; - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); - return !mask.IsTokenPairInsideMask(row, col); - }); + if(i_loop >= num_loops - 1) + { + set_tile_if( + sacc_tiles[i_k1], type_convert(0), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return !mask.IsTokenPairInsideMask(row, col); + }); + } } pcomp_tiles[i_k1] = cast_tile(sacc_tiles[i_k1]);