diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
index f7c2b674aa..b36f78a06f 100644
--- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
+++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
@@ -219,8 +219,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
 
     int window_size = arg_parser.get_int("local_len");
 
-    bool use_local = (window_size > 0);
-
     int contextual_seqlen = arg_parser.get_int("context_len");
 
     int min_full_attn_seqlen = arg_parser.get_int("minfull_len");
@@ -516,26 +514,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
     using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
 
-    BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
+    BOOL_SWITCH_2(is_jagged, kIsJagged, use_causal, kUseCausal, [&] {
         ck_tile::reference_hstu_attention<InOutDataType,
                                           GemmAccDataType,
                                           CompDataType,
                                           kIsJagged,
-                                          kUseCausal,
-                                          HstuBlockMasking<kUseCausal, kUseLocal>>::Run(q_host,
-                                                                                        k_host,
-                                                                                        v_host,
-                                                                                        o_host_ref,
-                                                                                        mask_host,
-                                                                                        num_batch,
-                                                                                        scale_s,
-                                                                                        attn_scale,
-                                                                                        max_seqlen,
-                                                                                        seq_offsets,
-                                                                                        num_targets,
-                                                                                        window_size,
-                                                                                        contextual_seqlen,
-                                                                                        min_full_attn_seqlen);
+                                          kUseCausal>::Run(q_host,
+                                                           k_host,
+                                                           v_host,
+                                                           o_host_ref,
+                                                           mask_host,
+                                                           num_batch,
+                                                           scale_s,
+                                                           attn_scale,
+                                                           max_seqlen,
+                                                           seq_offsets,
+                                                           num_targets,
+                                                           contextual_seqlen,
+                                                           window_size,
+                                                           min_full_attn_seqlen);
     });
 
     ck_tile::HostTensor<InOutDataType> o_host(
diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
index 90ada7e7a7..08f561620a 100644
--- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
+++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
@@ -29,13 +29,9 @@ template <typename InOutDataType,
           typename GemmAccDataType,
           typename CompDataType,
           bool kIsJagged,
-          bool kUseCausal,
-          typename HstuBlockMasking>
+          bool kUseCausal>
 struct reference_hstu_attention
 {
-    using HstuMask = typename HstuBlockMasking::Type;
-    static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
-
     static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
                     const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
                     const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
@@ -48,9 +44,9 @@ struct reference_hstu_attention
                    std::vector<int64_t> seq_offsets,
                    std::vector<int64_t> num_targets,  // define masking length at the end of token
                                                       // sequence to be excluded for attention
-                    int max_attn_len,         // define the diagonal local window size
                    int contextual_seqlen,    // define masking length at the begin of query token
                                              // sequence to be included for attention
+                    int window_size,          // define the diagonal local window size
                    int min_full_attn_seqlen) // define masking length at the end of query token
                                              // sequence which is included for full attention
     {
@@ -112,122 +108,130 @@ struct reference_hstu_attention
 
             int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
 
-            HstuMask mask = [&]() {
-                if constexpr(kHasLocalMask)
-                    // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
-                    // user passed min_full_attn_seqlen is bigger than max_uih_len
-                    if(seqlen - num_target > min_full_attn_seqlen)
-                        return ck_tile::make_hstu_block_mask_with_local(
-                            true,
-                            seqlen,
-                            contextual_seqlen,
-                            num_target,
-                            max_attn_len,
-                            min_full_attn_seqlen);
-                    else
-                        return ck_tile::make_hstu_block_mask_with_local(true,
-                                                                        seqlen,
-                                                                        contextual_seqlen,
-                                                                        num_target,
-                                                                        max_attn_len,
-                                                                        seqlen -
-                                                                            num_target);
-                else
-                    return ck_tile::make_hstu_block_mask_without_local(
-                        seqlen, contextual_seqlen, num_target);
-            }();
-
-            if(save_mask)
-            {
-                // initialize the mask
-                for(int sq = 0; sq < max_seqlen; sq++)
-                    for(int sk = 0; sk < max_seqlen; sk++)
-                        mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
-                            static_cast<InOutDataType>(mask.IsTokenPairInsideMask(sq, sk));
-            }
-
-            // for all rows in the batch
-            for(int sq = 0; sq < seqlen; sq++)
-            {
-                std::vector<CompDataType> locals;
-
-                // for all cols in the batch
-                for(int sk = 0; sk < seqlen; sk++)
-                {
-                    if(mask.IsTokenPairInsideMask(sq, sk))
-                    {
-                        GemmAccDataType dot_prod = 0.f;
-                        for(int k = 0; k < hdim_qk; k++)
-                        {
-                            if constexpr(kIsJagged)
-                            {
-                                InOutDataType qreg =
-                                    q_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k);
-                                InOutDataType kreg =
-                                    k_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
-
-                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
-                                            ck_tile::type_convert<GemmAccDataType>(kreg);
-                            }
-                            else
-                            {
-                                InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
-                                InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-
-                                dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
-                                            ck_tile::type_convert<GemmAccDataType>(kreg);
-                            };
-                        }
-
-                        locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
-                                         ck_tile::type_convert<CompDataType>(alpha));
-                    }
-                    else
-                        locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
-                };
-
-                // SiLu element-wise
-                for(CompDataType& elem : locals)
-                    elem = silu(elem);
-
-                float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
-
-                // second Gemm
-                for(int k = 0; k < hdim_v; k++)
-                {
-                    GemmAccDataType dot_prod = 0.f;
-
-                    for(int sk = 0; sk < seqlen; sk++)
-                    {
-                        if constexpr(kIsJagged)
-                        {
-                            InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
-                            InOutDataType vreg =
-                                v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
-
-                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
-                                        ck_tile::type_convert<GemmAccDataType>(vreg);
-                        }
-                        else
-                        {
-                            InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
-                            InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
-
-                            dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
-                                        ck_tile::type_convert<GemmAccDataType>(vreg);
-                        };
-                    };
-
-                    dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
-
-                    if constexpr(kIsJagged)
-                        o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
-                            ck_tile::type_convert<InOutDataType>(dot_prod);
-                    else
-                        o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
-                            ck_tile::type_convert<InOutDataType>(dot_prod);
-                };
-            };
+            float scale_p = attn_scale ? attn_scale : 1.0f / static_cast<float>(max_seqlen);
+
+            BOOL_SWITCH(window_size > 0, kHasLocal, [&] {
+                using HstuMaskType = typename HstuBlockMasking<kUseCausal, kHasLocal>::Type;
+
+                HstuMaskType mask = [&]() {
+                    if constexpr(kHasLocal)
+                        // need adjust the min_full_attn_seqlen passed to the HstuBlockMask() if the
+                        // user passed min_full_attn_seqlen is bigger than max_uih_len
+                        if(seqlen - num_target > min_full_attn_seqlen)
+                            return ck_tile::make_hstu_block_mask_with_local(
+                                true,
+                                seqlen,
+                                contextual_seqlen,
+                                num_target,
+                                window_size,
+                                min_full_attn_seqlen);
+                        else
+                            return ck_tile::make_hstu_block_mask_with_local(
+                                true,
+                                seqlen,
+                                contextual_seqlen,
+                                num_target,
+                                window_size,
+                                seqlen - num_target);
+                    else
+                        return ck_tile::make_hstu_block_mask_without_local(
+                            seqlen, contextual_seqlen, num_target);
+                }();
+
+                if(save_mask)
+                {
+                    // initialize the mask
+                    for(int sq = 0; sq < max_seqlen; sq++)
+                        for(int sk = 0; sk < max_seqlen; sk++)
+                            mask_batch_nhead_seq_seq(i_batch, i_head, sq, sk) =
+                                static_cast<InOutDataType>(mask.IsTokenPairInsideMask(sq, sk));
+                }
+
+                // for all rows in the batch
+                for(int sq = 0; sq < seqlen; sq++)
+                {
+                    std::vector<CompDataType> locals;
+
+                    // for all cols in the batch
+                    for(int sk = 0; sk < seqlen; sk++)
+                    {
+                        if(mask.IsTokenPairInsideMask(sq, sk))
+                        {
+                            GemmAccDataType dot_prod = 0.f;
+                            for(int k = 0; k < hdim_qk; k++)
+                            {
+                                if constexpr(kIsJagged)
+                                {
+                                    InOutDataType qreg = q_batch_seq_nhead_hdim(
+                                        0, seq_offsets[i_batch] + sq, i_head, k);
+                                    InOutDataType kreg = k_batch_seq_nhead_hdim(
+                                        0, seq_offsets[i_batch] + sk, i_head, k);
+
+                                    dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                                ck_tile::type_convert<GemmAccDataType>(kreg);
+                                }
+                                else
+                                {
+                                    InOutDataType qreg =
+                                        q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
+                                    InOutDataType kreg =
+                                        k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+
+                                    dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
+                                                ck_tile::type_convert<GemmAccDataType>(kreg);
+                                };
+                            }
+
+                            locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
+                                             ck_tile::type_convert<CompDataType>(alpha));
+                        }
+                        else
+                            locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
+                    };
+
+                    // SiLu element-wise
+                    for(CompDataType& elem : locals)
+                        elem = silu(elem);
+
+                    // second Gemm
+                    for(int k = 0; k < hdim_v; k++)
+                    {
+                        GemmAccDataType dot_prod = 0.f;
+
+                        for(int sk = 0; sk < seqlen; sk++)
+                        {
+                            if constexpr(kIsJagged)
+                            {
+                                InOutDataType preg =
+                                    ck_tile::type_convert<InOutDataType>(locals[sk]);
+                                InOutDataType vreg =
+                                    v_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sk, i_head, k);
+
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                            ck_tile::type_convert<GemmAccDataType>(vreg);
+                            }
+                            else
+                            {
+                                InOutDataType preg =
+                                    ck_tile::type_convert<InOutDataType>(locals[sk]);
+                                InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
+
+                                dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
+                                            ck_tile::type_convert<GemmAccDataType>(vreg);
+                            };
+                        };
+
+                        dot_prod = dot_prod * ck_tile::type_convert<GemmAccDataType>(scale_p);
+
+                        if constexpr(kIsJagged)
+                            o_batch_seq_nhead_hdim(0, seq_offsets[i_batch] + sq, i_head, k) =
+                                ck_tile::type_convert<InOutDataType>(dot_prod);
+                        else
+                            o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
+                                ck_tile::type_convert<InOutDataType>(dot_prod);
+                    };
+                };
+            });
         };
 
         make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency());
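
Note on the dispatch pattern: the key move in this diff is dropping the compile-time `kUseLocal`/`HstuBlockMasking` template parameters from `reference_hstu_attention` and instead selecting the mask type at runtime inside the reference via `BOOL_SWITCH(window_size > 0, kHasLocal, ...)`. The sketch below shows the general runtime-to-compile-time dispatch that `BOOL_SWITCH`-style macros expand to; `bool_switch` is a hypothetical stand-in written for illustration, not the actual ck_tile macro.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical illustration of the BOOL_SWITCH pattern (not the ck_tile macro):
// branch once on the runtime flag, then invoke the functor with a
// std::integral_constant so the lambda body can test it with `if constexpr`.
template <typename F>
void bool_switch(bool cond, F&& f)
{
    if(cond)
        std::forward<F>(f)(std::integral_constant<bool, true>{});
    else
        std::forward<F>(f)(std::integral_constant<bool, false>{});
}

// Usage mirroring the dispatch introduced above (window_size is a runtime
// value; kHasLocal becomes a compile-time constant inside the lambda):
//
//   bool_switch(window_size > 0, [&](auto has_local) {
//       constexpr bool kHasLocal = decltype(has_local)::value;
//       if constexpr(kHasLocal) { /* build mask with local window */ }
//       else                    { /* build mask without local window */ }
//   });
```

Since `window_size` does not vary across batches, observable behavior appears unchanged; the payoff is that callers such as the example's `BOOL_SWITCH_2` no longer have to thread a masking type through the reference's template signature.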
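A second detail worth calling out: the if/else around `make_hstu_block_mask_with_local` implements a clamp, so the `min_full_attn_seqlen` forwarded to the mask never exceeds the non-target prefix `seqlen - num_target` (the `max_uih_len` mentioned in the comment). A hypothetical helper expressing the same logic:

```cpp
#include <algorithm>

// Illustrative only, not part of the example: the effective value passed as
// min_full_attn_seqlen is capped at the non-target prefix length.
inline int effective_min_full_attn_seqlen(int min_full_attn_seqlen, int seqlen, int num_target)
{
    return std::min(min_full_attn_seqlen, seqlen - num_target);
}
```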