diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 6311a939d5..7457a6a335 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -149,19 +149,31 @@ struct HstuBlockMaskWithLocal // diagonal line are always considerred if constexpr(kUseCausal) { - if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)) - return true; - - if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)) - return true; + if(min_full_attn_seqlen > 0) + { + return (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); + } + else + { + return (((row_id > col_id) || (row == col)) && + (row_id - col_id <= max_attn_len)); + }; } else { - if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col)) - return true; - - if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)) - return true; + if(min_full_attn_seqlen > 0) + { + return (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); + } + else + { + return (((row_id != col_id) || (row == col)) && + (abs(row_id - col_id) <= max_attn_len)); + }; } return false; @@ -175,19 +187,31 @@ struct HstuBlockMaskWithLocal // diagonal line are always considerred if constexpr(kUseCausal) { - if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)) - return true; - - if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)) - return true; + if(min_full_attn_seqlen > 0) + { + return (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); + } + else + { + return (((row_id > col_id) || (row == col)) && + (row_id - col_id <= max_attn_len)); + }; } else { - if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col)) - return true; - - if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)) - return true; + if(min_full_attn_seqlen > 0) + { + return (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); + } + else + { + return (((row_id != col_id) || (row == col)) && + (abs(row_id - col_id) <= max_attn_len)); + }; } return false; @@ -213,20 +237,39 @@ struct HstuBlockMaskWithLocal // diagonal line are always considerred if constexpr(kUseCausal) { - bool res1 = - (((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)); - bool res2 = - ((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)); + if(min_full_attn_seqlen > 0) + { + bool res = (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); - return static_cast(res1 || res2); + return static_cast(res); + } + else + { + bool res = + (((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len)); + + return static_cast(res); + }; } else { - bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) || - (row == col)); - bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)); + if(min_full_attn_seqlen > 0) + { + bool res = (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); - return static_cast(res1 || res2); + return static_cast(res); + } + else + { + bool res = (((row_id != col_id) || (row == col)) && + (abs(row_id - col_id) <= max_attn_len)); + + return static_cast(res); + }; } } else @@ -240,20 +283,39 @@ struct HstuBlockMaskWithLocal // diagonal line are always considerred if constexpr(kUseCausal) { - bool res1 = - (((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)); - bool res2 = - ((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)); + if(min_full_attn_seqlen > 0) + { + bool res = (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); - return static_cast(res1 || res2); + return static_cast(res); + } + else + { + bool res = + (((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len)); + + return static_cast(res); + }; } else { - bool res1 = (((row_id != col_id) && (abs(row_id - col_id) <= max_attn_len)) || - (row == col)); - bool res2 = ((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)); + if(min_full_attn_seqlen > 0) + { + bool res = (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || + (row_id >= max_id - min_full_attn_seqlen))); - return static_cast(res1 || res2); + return static_cast(res); + } + else + { + bool res = (((row_id != col_id) || (row == col)) && + (abs(row_id - col_id) <= max_attn_len)); + + return static_cast(res); + }; } } };