Prepare the interface to support SWA

This commit is contained in:
Damien Lejeune
2026-05-07 13:52:56 +00:00
parent 3f076a6fc1
commit c132e6fc18
3 changed files with 51 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <optional>
@@ -52,7 +53,19 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "0", "permute output")
.insert("causal", "0", "0: no mask, 1: causal mask")
.insert("mask",
"b",
"attention mask. accepts the same syntax as 01_fmha:\n"
" '0' : no mask\n"
" '1' or 't' : causal mask from top-left\n"
" '2' or 'b' : causal mask from bottom-right (default)\n"
" 'xt:N'/'xb:N' : xformer-style window_size N from top-left/bottom-right\n"
" N<0 means causal, N>0 means sliding-window attention\n"
" 't:l,r'/'b:l,r' : FA-style left/right window from top-left/bottom-right\n"
" 'g:y,x' : generic mask coordinate\n"
"NOTE: today the CK unified attention kernel only supports no-mask and\n"
" bottom-right causal. Other values are accepted by the harness so\n"
" that future SWA work can drive failing/passing tests through it.")
.insert("verify", "1", "0:no verify, 1:verify")
.insert("varlen", "1", "0: fixed length, 1: variable length")
.insert("seed",
@@ -169,6 +182,20 @@ struct Problem
{
num_tokens += len;
}
mask_str = args.get_str("mask");
// Decode once with the maximum batch shape just for top-level reporting and
// for the kernel-side mask_type. The host reference re-decodes per-batch
// with each batch's effective seqlens (varlen-aware).
const ck_tile::index_t report_seqlen_q =
query_lens.empty()
? max_seqlen_q
: *std::max_element(query_lens.begin(), query_lens.end());
const ck_tile::index_t report_seqlen_kv =
kv_lens.empty()
? max_seqlen_kv
: *std::max_element(kv_lens.begin(), kv_lens.end());
mask = mask_info::decode(mask_str, report_seqlen_q, report_seqlen_kv);
}
std::vector<ck_tile::index_t> get_query_shape() const { return {num_tokens, nhead_q, hdim}; }
@@ -198,6 +225,7 @@ struct Problem
float scale;
float scale_k;
float scale_v;
std::string mask_str;
mask_info mask;
std::vector<int> query_lens;
std::vector<int> kv_lens;
@@ -256,7 +284,7 @@ template <typename AccDataType,
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
// const mask_info& mask,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
@@ -295,10 +323,18 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
-1, 0, seqlen_q, seqlen_kv, 1, false));
if(mask.type != mask_enum::no_mask)
{
// We always use the local-mask path: with left=-1 (causal) it
// degenerates to standard causal; with left>=0 it produces SWA. This keeps
// a single masking codepath for both causal and (future) sliding-window cases.
const bool is_top_left = (mask.type == mask_enum::mask_top_left);
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<
UnifiedAttentionMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv, /*repeat_idx=*/1, is_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
@@ -336,7 +372,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.num_head_q = problem.nhead_q;
args.num_queries_per_kv = problem.num_queries_per_kv;
args.page_blk_size = problem.page_blk_size;
args.mask_type = 2;
args.mask_type = static_cast<int>(problem.mask.type);
args.hdim = problem.hdim;
args.num_blks = problem.num_blks;
@@ -514,7 +550,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
if(i < problem.kv_lens.size() - 1)
std::cout << ",";
}
std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time
std::cout << "], mask:" << problem.mask << std::fixed << ", " << std::setprecision(8) << time
<< " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
@@ -566,11 +602,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
});
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Decode the mask freshly with this batch's effective seqlens so the host
// reference matches the per-batch attention shape (varlen-aware).
const auto batch_mask = mask_info::decode(problem.mask_str, seqlen_q_eff, seqlen_kv_eff);
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
// problem.mask,
batch_mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},

View File

@@ -71,7 +71,7 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens;
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
const index_t kBlockQ_medium = 128 / args.num_queries_per_kv;
[[maybe_unused]] const index_t kBlockQ_medium = 128 / args.num_queries_per_kv;
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
// seq has at most kBlockQ tokens. For mixed batches where some seqs have

View File

@@ -31,6 +31,7 @@ add_subdirectory(38_block_scale_gemm)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)
add_subdirectory(42_mx_gemm)
add_subdirectory(42_unified_attention)
add_subdirectory(50_sparse_attn)
add_subdirectory(51_tile_distr_enc_reg_map)