diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index dcaeca2665..7717f64fcd 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -80,6 +80,7 @@ struct HstuAttentionFwdKernel ck_tile::index_t nhead_stride_o; const int32_t* num_targets_ptr; + ck_tile::index_t contextual_seqlen; }; struct HstuAttentionFwdCommonBiasKargs @@ -97,7 +98,6 @@ struct HstuAttentionFwdKernel struct HstuAttentionFwdMaskKargs { ck_tile::index_t window_size; - ck_tile::index_t contextual_seqlen; ck_tile::index_t min_full_attn_seqlen; }; @@ -184,8 +184,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_o, const void* num_targets_ptr, - ck_tile::index_t window_size, ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, ck_tile::index_t min_full_attn_seqlen, float p_drop, const std::pair& drop_seed_offset) @@ -207,10 +207,11 @@ struct HstuAttentionFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - reinterpret_cast(num_targets_ptr)}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for dropout + reinterpret_cast(num_targets_ptr), + contextual_seqlen}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v, @@ -226,7 +227,6 @@ struct HstuAttentionFwdKernel if constexpr(kHasMask) { kargs.window_size = window_size; - kargs.contextual_seqlen = contextual_seqlen; kargs.min_full_attn_seqlen = min_full_attn_seqlen; } if constexpr(kHasDropout) @@ -267,8 +267,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_o, const void* num_targets_ptr, - ck_tile::index_t window_size, ck_tile::index_t contextual_seqlen, + ck_tile::index_t 
window_size, ck_tile::index_t min_full_attn_seqlen, float p_drop, uint64_t philox_seed, @@ -300,8 +300,8 @@ struct HstuAttentionFwdKernel batch_stride_bias, batch_stride_o, num_targets_ptr, - window_size, contextual_seqlen, + window_size, min_full_attn_seqlen, p_drop, std::make_pair(philox_seed, philox_offset)); @@ -330,8 +330,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, const void* num_targets_ptr, - ck_tile::index_t window_size, ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, ck_tile::index_t min_full_attn_seqlen, float p_drop, const std::pair& drop_seed_offset) @@ -353,10 +353,11 @@ struct HstuAttentionFwdKernel nhead_stride_k, nhead_stride_v, nhead_stride_o, - reinterpret_cast(num_targets_ptr)}, // args for common karg - {}, // placeholder for bias - {}, // placeholder for mask - {}, // placeholder for dropout + reinterpret_cast(num_targets_ptr), + contextual_seqlen}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for dropout reinterpret_cast(seq_offsets_ptr)}; if constexpr(kHasBias) @@ -368,7 +369,6 @@ struct HstuAttentionFwdKernel if constexpr(kHasMask) { kargs.window_size = window_size; - kargs.contextual_seqlen = contextual_seqlen; kargs.min_full_attn_seqlen = min_full_attn_seqlen; } if constexpr(kHasDropout) @@ -404,8 +404,8 @@ struct HstuAttentionFwdKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_o, const void* num_targets_ptr, - ck_tile::index_t window_size, ck_tile::index_t contextual_seqlen, + ck_tile::index_t window_size, ck_tile::index_t min_full_attn_seqlen, float p_drop, uint64_t philox_seed, @@ -432,8 +432,8 @@ struct HstuAttentionFwdKernel nhead_stride_bias, nhead_stride_o, num_targets_ptr, - window_size, contextual_seqlen, + window_size, min_full_attn_seqlen, p_drop, std::make_pair(philox_seed, philox_offset)); @@ -539,30 +539,17 @@ struct HstuAttentionFwdKernel batch_offset_o = 
static_cast(i_batch) * kargs.batch_stride_o; } - int max_uih_len = kargs.seqlen; - - if constexpr(kHasMask) - { - if(kargs.contextual_seqlen > 0) - max_uih_len -= kargs.contextual_seqlen - 1; - }; - - if(kargs.num_targets_ptr != nullptr) - { - if constexpr(kIsJagged) - max_uih_len -= kargs.num_targets_ptr[i_batch]; - else - max_uih_len -= kargs.num_targets_ptr[0]; - }; + int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; HstuMask mask = [&]() { if constexpr(kHasMask) return HstuMask{kargs.window_size, kargs.contextual_seqlen, kargs.min_full_attn_seqlen, - max_uih_len}; + kargs.seqlen, + num_target}; else - return HstuMask{0, 0, 0, 0}; + return HstuMask{0, kargs.contextual_seqlen, 0, kargs.seqlen, num_target}; }(); // for simplicity, batch stride we just modify the pointer diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index f43f45b4e1..87ebbb941b 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -373,21 +373,21 @@ struct HstuAttentionFwdPipelineQRKSVS if constexpr(HstuMask::IsMasking) { const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsTokenPairInsideMask(row, col); + return !mask.IsTokenPairInsideMask(row, col); }); } else if constexpr(kPadSeqLenK) { const auto k_origin = k_dram_block_window.get_window_origin(); - set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { - if(i_loop < num_loops) + set_tile_if(s_acc, type_convert(0), [&](auto tile_idx) { + if(i_loop < num_loops - 1) return false; const auto row = 
q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsTokenPairInsideMask(row, col); + return !mask.IsTokenPairInsideMask(row, col); }); }; diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 005b9b29b1..85bffdcf9e 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -20,12 +20,17 @@ struct HstuBlockMasking CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_, int contextual_seqlen_, int min_full_attn_seqlen_, - int max_uih_len_) + int seqlen_, + int num_target) { max_attn_len = max_attn_len_; contextual_seqlen = contextual_seqlen_; min_full_attn_seqlen = min_full_attn_seqlen_; - max_uih_len = max_uih_len_; + + max_uih_len = seqlen_; + + max_uih_len -= contextual_seqlen - 1; + max_uih_len -= num_target; }; // to get the loop length along X axis, return index:[start, end), end-start=length @@ -82,27 +87,34 @@ struct HstuBlockMasking CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) { + if(row >= max_uih_len || col >= max_uih_len) + return false; + if(row < contextual_seqlen) return true; - bool result = false; - if constexpr(kUseLocal) + if constexpr(IsMasking) { - if constexpr(kUseCausal) - result = (row >= col) && (row - col <= max_attn_len); + bool result = false; + if constexpr(kUseLocal) + { + if constexpr(kUseCausal) + result = (row >= col) && (row - col <= max_attn_len); + else + result = std::abs(row - col) <= max_attn_len; + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + } else - result = std::abs(row - col) <= max_attn_len; - - if(min_full_attn_seqlen > 0) - result = result || (row >= max_uih_len - min_full_attn_seqlen); - } - else - { - if constexpr(kUseCausal) + { result = (row >= col); - }; + }; - return result; + return result; 
+ } + + return true; }; }; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 1651e546d5..96e7922eff 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -83,6 +83,9 @@ struct reference_hstu_attention assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]); assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]); + // check num_targets + assert(num_targets.empty() || num_targets.size() == num_batch); + auto silu = [](CompDataType x) { auto one = ck_tile::type_convert(1.0f); @@ -91,33 +94,22 @@ struct reference_hstu_attention return sigmod_val * x; }; - bool has_target = !num_targets.empty(); - - if(has_target) - assert(num_targets.size() == num_batch); - auto f = [&](auto i_batch, auto i_head) { int seqlen = kIsJagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch]) : q_batch_seq_nhead_hdim.get_lengths()[1]; - int max_uih_len = seqlen; - - if(contextual_seqlen > 0) - max_uih_len -= contextual_seqlen - 1; - - if(has_target) - max_uih_len -= num_targets[i_batch]; + int num_target = num_targets.empty() ? 
0 : num_targets[i_batch]; HstuBlockMasking mask{ - max_attn_len, contextual_seqlen, min_full_attn_seqlen, max_uih_len}; + max_attn_len, contextual_seqlen, min_full_attn_seqlen, seqlen, num_target}; // for all rows in the batch - for(int sq = 0; sq < max_uih_len; sq++) + for(int sq = 0; sq < seqlen; sq++) { std::vector locals; // for all cols in the batch - for(int sk = 0; sk < max_uih_len; sk++) + for(int sk = 0; sk < seqlen; sk++) { if(mask.IsTokenPairInsideMask(sq, sk)) { @@ -153,14 +145,14 @@ struct reference_hstu_attention // SiLu element-wise for(CompDataType& elem : locals) - elem = silu(elem) / ck_tile::type_convert(seqlen); + elem = silu(elem); // second Gemm for(int k = 0; k < hdim_v; k++) { GemmAccDataType dot_prod = 0.f; - for(int sk = 0; sk < max_uih_len; sk++) + for(int sk = 0; sk < seqlen; sk++) { if constexpr(kIsJagged) { diff --git a/example/ck_tile/18_hstu_attention/test_hstu_attention.sh b/example/ck_tile/18_hstu_attention/test_hstu_attention.sh new file mode 100644 index 0000000000..c2b8076533 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/test_hstu_attention.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733 -causal=1 -local_len=5 -context_len=6 -minfull_len=6