diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 7457a6a335..e13a0ed534 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -130,193 +130,99 @@ struct HstuBlockMaskWithLocal }; } - CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) + CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col) { + int row_id; + int col_id; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = max(row - contextual_seqlen + 1, 0); - int col_id = max(col - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, 0); + col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_id); col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) return true; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(kUseCausal) - { - if(min_full_attn_seqlen > 0) - { - return (((row_id > col_id) || (row == col)) && - ((row_id - col_id <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - } - else - { - return (((row_id > col_id) || (row == col)) && - (row_id - col_id <= max_attn_len)); - }; - } - else - { - if(min_full_attn_seqlen > 0) - { - return (((row_id != col_id) || (row == col)) && - ((abs(row_id - col_id) <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - } - else - { - return (((row_id != col_id) || (row == col)) && - (abs(row_id - col_id) <= max_attn_len)); - }; - } - - return false; } else { - int row_id = min(row, max_id); - int col_id = min(col, max_id); + // row_id/col_id is clamped from physical row/col according to contextual_seqlen and + // max_uih_len + row_id = min(row, max_id); + col_id = min(col, max_id); + }; - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(kUseCausal) - { - if(min_full_attn_seqlen > 0) - { - return (((row_id > col_id) || (row == col)) && - ((row_id - col_id <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - } - else - { - return (((row_id > col_id) || (row == col)) && - (row_id - col_id <= max_attn_len)); - }; - } - else - { - if(min_full_attn_seqlen > 0) - { - return (((row_id != col_id) || (row == col)) && - ((abs(row_id - col_id) <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - } - else - { - return (((row_id != col_id) || (row == col)) && - (abs(row_id - col_id) <= max_attn_len)); - }; - } + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considerred + if constexpr(kUseCausal) + { + bool in_min_full_scope = + (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; - return false; + return (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || in_min_full_scope)); + } + else + { + bool in_min_full_scope = + (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; + + return (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope)); } }; - CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col) + CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col) { + int row_id; + int col_id; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = max(row - contextual_seqlen + 1, 0); - int col_id = max(col - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, 0); + col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_id); col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) return 1; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(kUseCausal) - { - if(min_full_attn_seqlen > 0) - { - bool res = (((row_id > col_id) || (row == col)) && - ((row_id - col_id <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - - return static_cast(res); - } - else - { - bool res = - (((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len)); - - return static_cast(res); - }; - } - else - { - if(min_full_attn_seqlen > 0) - { - bool res = (((row_id != col_id) || (row == col)) && - ((abs(row_id - col_id) <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); - - return static_cast(res); - } - else - { - bool res = (((row_id != col_id) || (row == col)) && - (abs(row_id - col_id) <= max_attn_len)); - - return static_cast(res); - }; - } } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = min(row, max_id); - int col_id = min(col, max_id); + row_id = min(row, max_id); + col_id = min(col, max_id); + }; - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(kUseCausal) - { - if(min_full_attn_seqlen > 0) - { - bool res = (((row_id > col_id) || (row == col)) && - ((row_id - col_id <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considerred + if constexpr(kUseCausal) + { + bool in_min_full_scope = + (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; - return static_cast(res); - } - else - { - bool res = - (((row_id > col_id) || (row == col)) && (row_id - col_id <= max_attn_len)); + bool res = (((row_id > col_id) || (row == col)) && + ((row_id - col_id <= max_attn_len) || in_min_full_scope)); - return static_cast(res); - }; - } - else - { - if(min_full_attn_seqlen > 0) - { - bool res = (((row_id != col_id) || (row == col)) && - ((abs(row_id - col_id) <= max_attn_len) || - (row_id >= max_id - min_full_attn_seqlen))); + return static_cast(res); + } + else + { + bool in_min_full_scope = + (min_full_attn_seqlen > 0) ? (row_id >= max_id - min_full_attn_seqlen) : false; - return static_cast(res); - } - else - { - bool res = (((row_id != col_id) || (row == col)) && - (abs(row_id - col_id) <= max_attn_len)); + bool res = (((row_id != col_id) || (row == col)) && + ((abs(row_id - col_id) <= max_attn_len) || in_min_full_scope)); - return static_cast(res); - }; - } + return static_cast(res); } }; @@ -397,104 +303,84 @@ struct HstuBlockMaskNoLocal }; } - CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) + CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col) { + int row_id; + int col_id; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = max(row - contextual_seqlen + 1, 0); - int col_id = max(col - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, 0); + col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_id); col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) return true; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) - { - return (row_id > col_id) || (row == col); - } - else - { - return (row_id != col_id) || (row == col); - }; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = min(row, max_id); - int col_id = min(col, max_id); + row_id = min(row, max_id); + col_id = min(col, max_id); + }; - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) - { - return (row_id > col_id) || (row == col); - } - else - { - return (row_id != col_id) || (row == col); - }; + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considerred + if constexpr(IsMasking) + { + return (row_id > col_id) || (row == col); + } + else + { + return (row_id != col_id) || (row == col); }; }; - CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col) + CK_TILE_DEVICE int IsTokenPairInsideMask(int row, int col) { + int row_id; + int col_id; + if(contextual_seqlen > 0) { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = max(row - contextual_seqlen + 1, 0); - int col_id = max(col - contextual_seqlen + 1, 0); + row_id = max(row - contextual_seqlen + 1, 0); + col_id = max(col - contextual_seqlen + 1, 0); row_id = min(row_id, max_id); col_id = min(col_id, max_id); if(row_id == 0 && col_id < max_id) return 1; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) - { - bool res = ((row_id > col_id) || (row == col)); - - return static_cast(res); - } - else - { - bool res = ((row_id != col_id) || (row == col)); - - return static_cast(res); - }; } else { // row_id/col_id is clamped from physical row/col according to contextual_seqlen and // max_uih_len - int row_id = min(row, max_id); - int col_id = min(col, max_id); + row_id = min(row, max_id); + col_id = min(col, max_id); + }; - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) - { - bool res = ((row_id > col_id) || (row == col)); + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considerred + if constexpr(IsMasking) + { + bool res = ((row_id > col_id) || (row == col)); - return static_cast(res); - } - else - { - bool res = ((row_id != col_id) || (row == col)); - - return static_cast(res); - }; + return static_cast(res); } + else + { + bool res = ((row_id != col_id) || (row == col)); + + return static_cast(res); + }; }; // if the whole tile inside the masking area, no need for pixel-by-pixel checking