diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 7896fac39e..dab1620fb7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -442,8 +442,8 @@ struct HstuAttentionFwdPipelineQRKSVS const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - pcomp_tile(i_j_idx) *= - static_cast(mask.IsTokenPairInsideMask(row, col)); + if(!mask.IsTokenPairInsideMask(row, col)) + pcomp_tile(i_j_idx) = static_cast(0.0f); }); }); } diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index f1bb1308dc..24c8de1d50 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -210,7 +210,7 @@ struct HstuBlockMaskWithLocal } }; - CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col) + CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col) { int row_id; int col_id; @@ -226,7 +226,7 @@ struct HstuBlockMaskWithLocal col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) - return 1; + return true; } else { @@ -245,7 +245,7 @@ struct HstuBlockMaskWithLocal bool res = (((row_id > col_id) || (row == col)) && ((row_id - col_id <= max_attn_len) || in_min_full_scope)); - return static_cast(res); + return res; } else { @@ -254,7 +254,7 @@ struct HstuBlockMaskWithLocal bool res = (((row_id != col_id) || (row == col)) && ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope)); - return static_cast(res); + return res; } }; @@ -383,7 +383,7 @@ struct HstuBlockMaskNoLocal }; }; - CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col) + CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col) { int row_id; int col_id; @@ -399,7 +399,7 @@ struct HstuBlockMaskNoLocal col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) - return 1; + return true; } else { @@ -415,13 +415,13 @@ struct HstuBlockMaskNoLocal { bool res = ((row_id > col_id) || (row == col)); - return static_cast(res); + return res; } else { bool res = ((row_id != col_id) || (row == col)); - return static_cast(res); + return res; }; };