diff --git a/example/ck_tile/42_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp index 03e5697ba0..abdeb461ab 100644 --- a/example/ck_tile/42_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/example_unified_attention.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include #include #include #include @@ -52,7 +53,19 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair0 means sliding-window attention\n" + " 't:l,r'/'b:l,r' : FA-style left/right window from top-left/bottom-right\n" + " 'g:y,x' : generic mask coordinate\n" + "NOTE: today the CK unified attention kernel only supports no-mask and\n" + " bottom-right causal. Other values are accepted by the harness so\n" + " that future SWA work can drive failing/passing tests through it.") .insert("verify", "1", "0:no verify, 1:verify") .insert("varlen", "1", "0: fixed length, 1: variable length") .insert("seed", @@ -169,6 +182,20 @@ struct Problem { num_tokens += len; } + + mask_str = args.get_str("mask"); + // Decode once with the maximum batch shape just for top-level reporting and + // for the kernel-side mask_type. The host reference re-decodes per-batch + // with each batch's effective seqlens (varlen-aware). + const ck_tile::index_t report_seqlen_q = + query_lens.empty() + ? max_seqlen_q + : *std::max_element(query_lens.begin(), query_lens.end()); + const ck_tile::index_t report_seqlen_kv = + kv_lens.empty() + ? max_seqlen_kv + : *std::max_element(kv_lens.begin(), kv_lens.end()); + mask = mask_info::decode(mask_str, report_seqlen_q, report_seqlen_kv); } std::vector get_query_shape() const { return {num_tokens, nhead_q, hdim}; } @@ -198,6 +225,7 @@ struct Problem float scale; float scale_k; float scale_v; + std::string mask_str; mask_info mask; std::vector query_lens; std::vector kv_lens; @@ -256,7 +284,7 @@ template & q_bshd, const ck_tile::HostTensor& k_bshd, const ck_tile::HostTensor& v_bshd, - // const mask_info& mask, + const mask_info& mask, ck_tile::HostTensor& o_bshd, const QElementOp& q_element_op = {}, const KElementOp& k_element_op = {}, @@ -295,10 +323,18 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - -1, 0, seqlen_q, seqlen_kv, 1, false)); + if(mask.type != mask_enum::no_mask) + { + // We always use the local-mask path: with left=-1 (causal) it + // degenerates to standard causal; with left>=0 it produces SWA. This keeps + // a single masking codepath for both causal and (future) sliding-window cases. + const bool is_top_left = (mask.type == mask_enum::mask_top_left); + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window< + UnifiedAttentionMasks::GenericMask>( + mask.left, mask.right, seqlen_q, seqlen_kv, /*repeat_idx=*/1, is_top_left)); + } ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -336,7 +372,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_head_q = problem.nhead_q; args.num_queries_per_kv = problem.num_queries_per_kv; args.page_blk_size = problem.page_blk_size; - args.mask_type = 2; + args.mask_type = static_cast(problem.mask.type); args.hdim = problem.hdim; args.num_blks = problem.num_blks; @@ -514,7 +550,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) if(i < problem.kv_lens.size() - 1) std::cout << ","; } - std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time + std::cout << "], mask:" << problem.mask << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; @@ -566,11 +602,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }); // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + // Decode the mask freshly with this batch's effective seqlens so the host + // reference matches the per-batch attention shape (varlen-aware). + const auto batch_mask = mask_info::decode(problem.mask_str, seqlen_q_eff, seqlen_kv_eff); + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) host::fmha_fwd(q_b, k_b, v_b, - // problem.mask, + batch_mask, o_b, ck_tile::identity{}, ck_tile::identity{}, diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index bdeb56aed9..d95a3889b8 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -71,7 +71,7 @@ static tile_tier select_tile_tier(const unified_attention_args& args) const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; const index_t kBlockQ_small = 64 / args.num_queries_per_kv; - const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; + [[maybe_unused]] const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; // Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each // seq has at most kBlockQ tokens. For mixed batches where some seqs have diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 16a617fb26..f7699861a0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) add_subdirectory(42_mx_gemm) +add_subdirectory(42_unified_attention) add_subdirectory(50_sparse_attn) add_subdirectory(51_tile_distr_enc_reg_map)