From 36a0f2020c666f158ace12adaa2da1ed83ded1b2 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 29 May 2025 01:02:20 +0000 Subject: [PATCH] not-critical updates in example and block_masking codes --- .../example_hstu_attention.cpp | 8 +- .../18_hstu_attention/hstu_block_masking.hpp | 128 ++++++++++++------ 2 files changed, 94 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index bcf44508e3..48fcba02ff 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -391,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser) HstuAttentionFwdParams params; + float scale_s = 1.0f / std::sqrt(hdim_qk); + if(is_jagged) { params.is_jagged = true; @@ -405,7 +407,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.hdim_qk = hdim_qk; params.hdim_v = hdim_v; params.num_head = num_head; - params.scale_s = 1.0f / std::sqrt(params.hdim_qk); + params.scale_s = scale_s; params.seq_stride_q = q_host.get_strides()[1]; params.seq_stride_k = k_host.get_strides()[1]; params.seq_stride_v = v_host.get_strides()[1]; @@ -438,7 +440,7 @@ bool run(const ck_tile::ArgParser& arg_parser) params.hdim_qk = hdim_qk; params.hdim_v = hdim_v; params.num_head = num_head; - params.scale_s = 1.0f / std::sqrt(params.hdim_qk); + params.scale_s = scale_s; params.seq_stride_q = q_host.get_strides()[1]; params.seq_stride_k = k_host.get_strides()[1]; params.seq_stride_v = v_host.get_strides()[1]; @@ -507,7 +509,7 @@ bool run(const ck_tile::ArgParser& arg_parser) o_host_ref, mask_host, num_batch, - 1.0f / std::sqrt(params.hdim_qk), + scale_s, max_seqlen, seq_offsets, num_targets, diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 22e3349105..16b080227d 100644 --- 
a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -132,37 +132,66 @@ struct HstuBlockMaskWithLocal CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) { - // row_id/col_id is clamped from physical row/col according to contextual_seqlen and - // max_uih_len - int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row; - int col_id = contextual_seqlen > 0 ? max(col - contextual_seqlen + 1, 0) : col; - - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); - - if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id) - return true; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(kUseCausal) + if(contextual_seqlen > 0) { - if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)) + // row_id/col_id is clamped from physical row/col according to contextual_seqlen and + // max_uih_len + int row_id = max(row - contextual_seqlen + 1, 0); + int col_id = max(col - contextual_seqlen + 1, 0); + + row_id = min(row_id, max_id); + col_id = min(col_id, max_id); + + if(row_id == 0 && col_id < max_id) return true; - if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)) - return true; + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considered + if constexpr(kUseCausal) + { + if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)) + return true; + + if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)) + return true; + } + else + { + if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col)) + return true; + + if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)) + return true; + } + + return false; } else { - if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || 
(row == col)) - return true; + int row_id = min(row, max_id); + int col_id = min(col, max_id); - if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)) - return true; + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considered + if constexpr(kUseCausal) + { + if(((row_id > col_id) && (row_id - col_id <= max_attn_len)) || (row == col)) + return true; + + if((min_full_attn_seqlen > 0) && (row_id >= max_id - min_full_attn_seqlen)) + return true; + } + else + { + if(((row_id != col_id && abs(row_id - col_id) <= max_attn_len)) || (row == col)) + return true; + + if((min_full_attn_seqlen > 0) && (row >= max_id - min_full_attn_seqlen)) + return true; + } + + return false; } - - return false; }; CK_TILE_DEVICE constexpr int IsTokenPairInsideMask(int row, int col) @@ -303,26 +332,47 @@ struct HstuBlockMaskNoLocal CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) { - // row_id/col_id is clamped from physical row/col according to contextual_seqlen and - // max_uih_len - int row_id = contextual_seqlen > 0 ? max(row - contextual_seqlen + 1, 0) : row; - int col_id = contextual_seqlen > 0 ? 
max(col - contextual_seqlen + 1, 0) : col; - - row_id = min(row_id, max_id); - col_id = min(col_id, max_id); - - if(contextual_seqlen > 0 && row_id == 0 && col_id < max_id) - return true; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) + if(contextual_seqlen > 0) { - return (row_id > col_id) || (row == col); + // row_id/col_id is clamped from physical row/col according to contextual_seqlen and + // max_uih_len + int row_id = max(row - contextual_seqlen + 1, 0); + int col_id = max(col - contextual_seqlen + 1, 0); + + row_id = min(row_id, max_id); + col_id = min(col_id, max_id); + + if(row_id == 0 && col_id < max_id) + return true; + + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considered + if constexpr(IsMasking) + { + return (row_id > col_id) || (row == col); + } + else + { + return (row_id != col_id) || (row == col); + }; } else { - return (row_id != col_id) || (row == col); + // no contextual tokens in this branch: row_id/col_id are the physical row/col clamped to + // max_id + int row_id = min(row, max_id); + int col_id = min(col, max_id); + + // use row_id/col_id to check the dist between two q/k token pair, token pairs on the + // diagonal line are always considered + if constexpr(IsMasking) + { + return (row_id > col_id) || (row == col); + } + else + { + return (row_id != col_id) || (row == col); + }; }; };