// Patch summary (leading text before the first "diff --git" header):
//  * hstu_attention_fwd_kernel.hpp: the Q/K/V/O naive tensor views pass mask.max_uih_len instead
//    of kargs.seqlen as the first length given to make_naive_tensor_view.
//  * hstu_attention_fwd_pipeline.hpp: removes the else-branch that zeroed s_acc through set_tile_if
//    when q_origin row + kM0 exceeded mask.max_uih_len.
//  * hstu_block_masking.hpp: IsTokenPairInsideMask changes from CK_TILE_HOST_DEVICE to CK_TILE_HOST
//    and gains CK_TILE_DEVICE overloads that omit the row/col >= max_uih_len early return; the
//    max_uih_len_ computation in both make_hstu_block_mask_* helpers becomes an immediately
//    invoked lambda.
//  NOTE(review): the removed/added pad_tensor_view "sequence{}" lines read identically in this copy,
//  presumably because template arguments were stripped -- confirm against the original patch.
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 242abfc73c..792d9ed44e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -575,7 +575,7 @@ struct HstuAttentionFwdKernel const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( q_ptr, - make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(mask.max_uih_len, kargs.hdim_qk), make_tuple(kargs.seq_stride_q, 1), number{}, number<1>{}); @@ -584,20 +584,20 @@ struct HstuAttentionFwdKernel return pad_tensor_view(q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(mask.max_uih_len, kargs.hdim_qk), make_tuple(kargs.seq_stride_k, 1), number{}, number<1>{}); @@ -612,7 +612,7 @@ struct HstuAttentionFwdKernel { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(mask.max_uih_len, kargs.hdim_v), make_tuple(kargs.seq_stride_v, 1), number{}, number<1>{}); @@ -732,7 +732,7 @@ struct HstuAttentionFwdKernel auto o_dram = [&]() { const auto o_dram_naive = make_naive_tensor_view( o_ptr, - make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(mask.max_uih_len, kargs.hdim_v), make_tuple(kargs.seq_stride_o, 1), number{}, number<1>{}); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 6560df75db..4942b67606 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -388,20 +388,8 @@ struct 
HstuAttentionFwdPipelineQRKSVS return !mask.IsTokenPairInsideMask(row, col); }); } - else - { - if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !mask.IsTokenPairInsideMask(row, col); - }); - }; - }; - auto s = cast_tile(s_acc); // S{j} + auto s = cast_tile(s_acc); tile_elementwise_inout(f_silu, s); diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 580493d366..99ce3e0447 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -68,7 +68,7 @@ struct HstuBlockMaskWithLocal }; } - CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) + CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) { if(row >= max_uih_len || col >= max_uih_len) return false; @@ -87,6 +87,25 @@ struct HstuBlockMaskWithLocal return result; }; + + // Device-side mask check: row/col need not be compared against max_uih_len here, because + // buffer_load_xxx instructions are able to return zero for out-of-bounds accesses + CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) + { + if(row < contextual_seqlen) + return true; + + bool result = false; + if constexpr(kUseCausal) + result = (row >= col) && (row - col <= max_attn_len); + else + result = abs(row - col) <= max_attn_len; + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + + return result; + }; }; template @@ -123,7 +142,7 @@ struct HstuBlockMaskNoLocal }; } - CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) + CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col) { if(row >= 
max_uih_len || col >= max_uih_len) return false; @@ -140,6 +159,23 @@ struct HstuBlockMaskNoLocal return true; }; + + // Device-side mask check: row/col need not be compared against max_uih_len here, because + // buffer_load_xxx instructions are able to return zero for out-of-bounds accesses + CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) + { + if(row < contextual_seqlen) + return true; + + if constexpr(IsMasking) + { + bool result = (row >= col); + + return result; + } + + return true; + }; }; template @@ -157,10 +193,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_, int max_attn_len_, int min_full_attn_seqlen_) { - auto max_uih_len_ = seqlen_; - - max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0; - max_uih_len_ -= num_target; + auto max_uih_len_ = [&]() { + if(contextual_seqlen_ > 0) + return seqlen_ - (contextual_seqlen_ - 1) - num_target; + else + return seqlen_ - num_target; + }(); return HstuBlockMaskType{ contextual_seqlen_, max_uih_len_, max_attn_len_, min_full_attn_seqlen_}; @@ -170,10 +208,12 @@ template CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target) { - auto max_uih_len_ = seqlen_; - - max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0; - max_uih_len_ -= num_target; + auto max_uih_len_ = [&]() { + if(contextual_seqlen_ > 0) + return seqlen_ - (contextual_seqlen_ - 1) - num_target; + else + return seqlen_ - num_target; + }(); return HstuBlockMaskType{contextual_seqlen_, max_uih_len_}; };