diff --git a/example/ck_tile/42_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp index 03e5697ba0..95508f8a42 100644 --- a/example/ck_tile/42_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/example_unified_attention.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include #include #include #include @@ -52,7 +53,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair0 means sliding-window attention\n" + " 't:l,r'/'b:l,r' : FA-style left/right window from top-left/bottom-right\n" + " 'g:y,x' : generic mask coordinate") .insert("verify", "1", "0:no verify, 1:verify") .insert("varlen", "1", "0: fixed length, 1: variable length") .insert("seed", @@ -169,6 +179,20 @@ struct Problem { num_tokens += len; } + + mask_str = args.get_str("mask"); + // Decode once with the maximum batch shape for top-level reporting and + // for the kernel-side mask_type. The host reference re-decodes per-batch + // with each batch's effective seqlens (varlen-aware) inside run_impl. + const ck_tile::index_t report_seqlen_q = + query_lens.empty() + ? max_seqlen_q + : *std::max_element(query_lens.begin(), query_lens.end()); + const ck_tile::index_t report_seqlen_kv = + kv_lens.empty() + ? max_seqlen_kv + : *std::max_element(kv_lens.begin(), kv_lens.end()); + mask = mask_info::decode(mask_str, report_seqlen_q, report_seqlen_kv); } std::vector get_query_shape() const { return {num_tokens, nhead_q, hdim}; } @@ -198,6 +222,7 @@ struct Problem float scale; float scale_k; float scale_v; + std::string mask_str; mask_info mask; std::vector query_lens; std::vector kv_lens; @@ -256,7 +281,7 @@ template & q_bshd, const ck_tile::HostTensor& k_bshd, const ck_tile::HostTensor& v_bshd, - // const mask_info& mask, + const mask_info& mask, ck_tile::HostTensor& o_bshd, const QElementOp& q_element_op = {}, const KElementOp& k_element_op = {}, @@ -295,10 +320,19 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - -1, 0, seqlen_q, seqlen_kv, 1, false)); + if(mask.type != mask_enum::no_mask) + { + // Always use the GenericMask (IsLocal=true) path so both classical causal + // (left=-1, right=0) and sliding-window (left>=0) flow through the same + // codepath. The helper translates left/right into y/x mask coordinates + // and is_top_left selects the corner. + const bool is_top_left = (mask.type == mask_enum::mask_top_left); + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window< + UnifiedAttentionMasks::GenericMask>( + mask.left, mask.right, seqlen_q, seqlen_kv, /*repeat_idx=*/1, is_top_left)); + } ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -336,7 +370,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_head_q = problem.nhead_q; args.num_queries_per_kv = problem.num_queries_per_kv; args.page_blk_size = problem.page_blk_size; - args.mask_type = 2; + args.mask_type = static_cast(problem.mask.type); args.hdim = problem.hdim; args.num_blks = problem.num_blks; @@ -514,7 +548,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) if(i < problem.kv_lens.size() - 1) std::cout << ","; } - std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time + std::cout << "], mask:" << problem.mask << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; @@ -566,11 +600,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }); // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + // Decode the mask freshly with this batch's effective seqlens so the host + // reference matches the per-batch attention shape (varlen-aware). + const auto batch_mask = mask_info::decode(problem.mask_str, seqlen_q_eff, seqlen_kv_eff); + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) host::fmha_fwd(q_b, k_b, v_b, - // problem.mask, + batch_mask, o_b, ck_tile::identity{}, ck_tile::identity{}, diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index dd735c7689..8e02fcbc30 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -207,6 +207,15 @@ struct GenericAttentionMask } } + // Attention-sink aware variant. The unified-attention kernel does not yet + // implement sink tokens, so for now this is an alias for IsOutOfBound. + // Host reference code (reference_batched_masking) calls this; keeping the + // alias decouples that call site from future sink support. + CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const + { + return IsOutOfBound(i_y, i_x); + } + // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y()