From 3cd1b13e46b4a714e701f825df85b89524b75e4e Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Tue, 15 Apr 2025 14:40:55 +0000 Subject: [PATCH] Split HstuBlockMasking into HstuBlockMaskWithLocal and HstuBlockMaskNoLocal to save vgprs for non-local situations --- .../example_hstu_attention.cpp | 4 +- ...stu_attention_batched_forward_dispatch.hpp | 2 +- .../hstu_attention_fwd_kernel.hpp | 61 +++--- ...hstu_attention_jagged_forward_dispatch.hpp | 2 +- .../18_hstu_attention/hstu_block_masking.hpp | 185 +++++++++++------- .../reference_hstu_attention.hpp | 19 +- 6 files changed, 163 insertions(+), 110 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index b33a7f8dee..3d3da4f604 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara template auto get_elimit() { - double rtol = 1e-2; - double atol = 1e-2; + double rtol = 2e-2; + double atol = 2e-2; return ck_tile::make_tuple(rtol, atol); } diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp index 01523bacea..25e15eb458 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -32,7 +32,7 @@ template ::Type; - using HstuMask = ck_tile::HstuBlockMasking; + using HstuMask = typename ck_tile::HstuBlockMasking::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 56734e88cd..242abfc73c 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -11,6 +11,8 @@ #include #include +#include "hstu_block_masking.hpp" + // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -44,7 +46,7 @@ struct HstuAttentionFwdKernel static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias; static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout; using HstuMask = ck_tile::remove_cvref_t; - static constexpr bool kHasMask = HstuMask::IsMasking; + static constexpr bool kHasLocalMask = HstuMask::kUseLocal; template // to avoid duplicated base class problem, introduce an template // arg @@ -124,15 +126,16 @@ struct HstuAttentionFwdKernel uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); }; - struct HstuAttentionFwdBatchModeKargs - : HstuAttentionFwdCommonKargs, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -140,15 +143,16 @@ struct HstuAttentionFwdKernel ck_tile::index_t batch_stride_o; }; - struct HstuAttentionFwdJaggModeKargs - : HstuAttentionFwdCommonKargs, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { const int32_t* seq_offsets_ptr; }; @@ -224,7 +228,7 @@ struct HstuAttentionFwdKernel kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } - if constexpr(kHasMask) + if constexpr(kHasLocalMask) { kargs.window_size = window_size; kargs.min_full_attn_seqlen = min_full_attn_seqlen; @@ -366,7 +370,7 @@ struct HstuAttentionFwdKernel kargs.seq_stride_bias = seq_stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } - if constexpr(kHasMask) + if constexpr(kHasLocalMask) { kargs.window_size = window_size; kargs.min_full_attn_seqlen = min_full_attn_seqlen; @@ -542,14 +546,15 @@ struct HstuAttentionFwdKernel int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch]; HstuMask mask = [&]() { - if constexpr(kHasMask) - return HstuMask{kargs.seqlen, - kargs.contextual_seqlen, - num_target, - kargs.window_size, - kargs.min_full_attn_seqlen}; + if constexpr(kHasLocalMask) + return make_hstu_block_mask_with_local(kargs.seqlen, + kargs.contextual_seqlen, + num_target, + kargs.window_size, + kargs.min_full_attn_seqlen); else - return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target}; + return make_hstu_block_mask_without_local( + kargs.seqlen, kargs.contextual_seqlen, num_target); }(); // for simplicity, batch stride we just modify the pointer diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp index df6e3dddae..b90778e939 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -32,7 +32,7 @@ template ::Type; - using HstuMask = ck_tile::HstuBlockMasking; + using HstuMask = typename ck_tile::HstuBlockMasking::Type; template using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index ef01876fcd..580493d366 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -3,14 +3,16 @@ #pragma once +#include #include "ck_tile/core.hpp" namespace ck_tile { -template -struct HstuBlockMasking +template +struct HstuBlockMaskWithLocal { - static constexpr bool IsMasking = (kUseCausal || kUseLocal); + static constexpr bool kUseLocal = true; + static constexpr bool IsMasking = true; int contextual_seqlen; int max_uih_len; @@ -18,30 +20,86 @@ struct HstuBlockMasking int max_attn_len; int min_full_attn_seqlen; - CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, - int contextual_seqlen_, - int num_target, - int max_attn_len_, - int min_full_attn_seqlen_) + CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int contextual_seqlen_, + int max_uih_len_, + int max_attn_len_, + int min_full_attn_seqlen_) + : contextual_seqlen(contextual_seqlen_), + max_uih_len(max_uih_len_), + max_attn_len(max_attn_len_), + min_full_attn_seqlen(min_full_attn_seqlen_){}; + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) + // i_y is the start offset of the current tile along the seqlen_q dimension + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const { - max_uih_len = seqlen_; - contextual_seqlen = contextual_seqlen_; + if(i_y < contextual_seqlen) + return ck_tile::make_tuple(0, max_uih_len); - max_attn_len = max_attn_len_; - min_full_attn_seqlen = min_full_attn_seqlen_; + if constexpr(!kUseCausal) + { + if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) + { + return ck_tile::make_tuple(0, max_uih_len); + } + else + { + index_t x_start = max(0, i_y - max_attn_len); + index_t x_end = i_y + YTile + max_attn_len; - max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0; - max_uih_len -= num_target; - }; + return ck_tile::make_tuple(x_start - x_start % XTile, x_end); + }; + } + else // kUseCausal && kUseLocal + { + if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) + { + return ck_tile::make_tuple(0, max_uih_len); + } + else + { + index_t x_end = i_y + YTile + max_attn_len; - CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target) + return ck_tile::make_tuple(0, x_end); + }; + }; + } + + CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) { - max_uih_len = seqlen_; - contextual_seqlen = contextual_seqlen_; + if(row >= max_uih_len || col >= max_uih_len) + return false; - max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0; - max_uih_len -= num_target; + if(row < contextual_seqlen) + return true; + + bool result = false; + if constexpr(kUseCausal) + result = (row >= col) && (row - col <= max_attn_len); + else + result = abs(row - col) <= max_attn_len; + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + + return result; }; +}; + +template +struct HstuBlockMaskNoLocal +{ + static constexpr bool kUseLocal = false; + static constexpr bool IsMasking = kUseCausal; + + int contextual_seqlen; + int max_uih_len; + + CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int contextual_seqlen_, int max_uih_len_) + : contextual_seqlen(contextual_seqlen_), max_uih_len(max_uih_len_){}; // to get the loop length along X axis, return index:[start, end), end-start=length // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) @@ -59,40 +117,9 @@ struct HstuBlockMasking if(i_y < contextual_seqlen) return ck_tile::make_tuple(0, max_uih_len); - if constexpr(kUseCausal && !kUseLocal) - { - index_t x_end = - min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y + index_t x_end = min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y - return ck_tile::make_tuple(0, x_end); - } - else if constexpr(!kUseCausal && kUseLocal) - { - if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) - { - return ck_tile::make_tuple(0, max_uih_len); - } - else - { - index_t x_start = max(0, i_y - max_attn_len); - index_t x_end = i_y + YTile + max_attn_len; - - return ck_tile::make_tuple(x_start - x_start % XTile, x_end); - }; - } - else // kUseCausal && kUseLocal - { - if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) - { - return ck_tile::make_tuple(0, max_uih_len); - } - else - { - index_t x_end = i_y + YTile + max_attn_len; - - return ck_tile::make_tuple(0, x_end); - }; - }; + return ck_tile::make_tuple(0, x_end); }; } @@ -106,21 +133,7 @@ struct HstuBlockMasking if constexpr(IsMasking) { - bool result = false; - if constexpr(kUseLocal) - { - if constexpr(kUseCausal) - result = (row >= col) && (row - col <= max_attn_len); - else - result = abs(row - col) <= max_attn_len; - - if(min_full_attn_seqlen > 0) - result = (row >= max_uih_len - min_full_attn_seqlen); - } - else - { - result = (row >= col); - }; + bool result = (row >= col); return result; } @@ -129,4 +142,40 @@ struct HstuBlockMasking }; }; +template +struct HstuBlockMasking +{ + using Type = std::conditional_t, + HstuBlockMaskNoLocal>; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_, + int contextual_seqlen_, + int num_target, + int max_attn_len_, + int min_full_attn_seqlen_) +{ + auto max_uih_len_ = seqlen_; + + max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0; + max_uih_len_ -= num_target; + + return HstuBlockMaskType{ + contextual_seqlen_, max_uih_len_, max_attn_len_, min_full_attn_seqlen_}; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto +make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target) +{ + auto max_uih_len_ = seqlen_; + + max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0; + max_uih_len_ -= num_target; + + return HstuBlockMaskType{contextual_seqlen_, max_uih_len_}; +}; + } // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index e8ea6fa1f0..b5c9769fab 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -33,8 +33,8 @@ template struct reference_hstu_attention { - using HstuMask = HstuBlockMasking; - static constexpr bool kHasMask = kUseCausal || kUseLocal; + using HstuMask = typename HstuBlockMasking::Type; + static constexpr bool kHasLocalMask = HstuMask::kUseLocal; static void Run(const HostTensor& q_batch_seq_nhead_hdim, const HostTensor& k_batch_seq_nhead_hdim, @@ -90,11 +90,9 @@ struct reference_hstu_attention assert(num_tagets.empty() || num_targets.size() == num_batch); auto silu = [](CompDataType x) { - auto one = ck_tile::type_convert(1.0f); + const auto one = ck_tile::type_convert(1.0f); - auto sigmod_val = one / (one + std::exp(-x)); - - return sigmod_val * x; + return x / (one + std::exp(-x)); }; auto f = [&](auto i_batch, auto i_head) { @@ -104,11 +102,12 @@ struct reference_hstu_attention int num_target = num_targets.empty() ? 0 : num_targets[i_batch]; HstuMask mask = [&]() { - if constexpr(kHasMask) - return HstuMask{ - seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen}; + if constexpr(kHasLocalMask) + return ck_tile::make_hstu_block_mask_with_local( + seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen); else - return HstuMask{seqlen, contextual_seqlen, num_target}; + return ck_tile::make_hstu_block_mask_without_local( + seqlen, contextual_seqlen, num_target); }(); // for all rows in the batch