mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Add the IsOutOfSinkBound alias + update mask cmd line argument in the example
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
@@ -52,7 +53,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("mask",
|
||||
"b",
|
||||
"attention mask. accepts the same syntax as 01_fmha:\n"
|
||||
" '0' : no mask\n"
|
||||
" '1' or 't' : causal mask from top-left\n"
|
||||
" '2' or 'b' : causal mask from bottom-right (default)\n"
|
||||
" 'xt:N'/'xb:N' : xformer-style window_size N from top-left/bottom-right\n"
|
||||
" N<0 means causal, N>0 means sliding-window attention\n"
|
||||
" 't:l,r'/'b:l,r' : FA-style left/right window from top-left/bottom-right\n"
|
||||
" 'g:y,x' : generic mask coordinate")
|
||||
.insert("verify", "1", "0:no verify, 1:verify")
|
||||
.insert("varlen", "1", "0: fixed length, 1: variable length")
|
||||
.insert("seed",
|
||||
@@ -169,6 +179,20 @@ struct Problem
|
||||
{
|
||||
num_tokens += len;
|
||||
}
|
||||
|
||||
mask_str = args.get_str("mask");
|
||||
// Decode once with the maximum batch shape for top-level reporting and
|
||||
// for the kernel-side mask_type. The host reference re-decodes per-batch
|
||||
// with each batch's effective seqlens (varlen-aware) inside run_impl.
|
||||
const ck_tile::index_t report_seqlen_q =
|
||||
query_lens.empty()
|
||||
? max_seqlen_q
|
||||
: *std::max_element(query_lens.begin(), query_lens.end());
|
||||
const ck_tile::index_t report_seqlen_kv =
|
||||
kv_lens.empty()
|
||||
? max_seqlen_kv
|
||||
: *std::max_element(kv_lens.begin(), kv_lens.end());
|
||||
mask = mask_info::decode(mask_str, report_seqlen_q, report_seqlen_kv);
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_query_shape() const { return {num_tokens, nhead_q, hdim}; }
|
||||
@@ -198,6 +222,7 @@ struct Problem
|
||||
float scale;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
std::string mask_str;
|
||||
mask_info mask;
|
||||
std::vector<int> query_lens;
|
||||
std::vector<int> kv_lens;
|
||||
@@ -256,7 +281,7 @@ template <typename AccDataType,
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
// const mask_info& mask,
|
||||
const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
@@ -295,10 +320,19 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
|
||||
-1, 0, seqlen_q, seqlen_kv, 1, false));
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
// Always use the GenericMask (IsLocal=true) path so both classical causal
|
||||
// (left=-1, right=0) and sliding-window (left>=0) flow through the same
|
||||
// codepath. The helper translates left/right into y/x mask coordinates
|
||||
// and is_top_left selects the corner.
|
||||
const bool is_top_left = (mask.type == mask_enum::mask_top_left);
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<
|
||||
UnifiedAttentionMasks::GenericMask>(
|
||||
mask.left, mask.right, seqlen_q, seqlen_kv, /*repeat_idx=*/1, is_top_left));
|
||||
}
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
@@ -336,7 +370,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = problem.num_queries_per_kv;
|
||||
args.page_blk_size = problem.page_blk_size;
|
||||
args.mask_type = 2;
|
||||
args.mask_type = static_cast<int>(problem.mask.type);
|
||||
args.hdim = problem.hdim;
|
||||
|
||||
args.num_blks = problem.num_blks;
|
||||
@@ -514,7 +548,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
if(i < problem.kv_lens.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time
|
||||
std::cout << "], mask:" << problem.mask << std::fixed << ", " << std::setprecision(8) << time
|
||||
<< " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
|
||||
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
|
||||
|
||||
@@ -566,11 +600,15 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
});
|
||||
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Decode the mask freshly with this batch's effective seqlens so the host
|
||||
// reference matches the per-batch attention shape (varlen-aware).
|
||||
const auto batch_mask = mask_info::decode(problem.mask_str, seqlen_q_eff, seqlen_kv_eff);
|
||||
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
// problem.mask,
|
||||
batch_mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
|
||||
@@ -207,6 +207,15 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
// Attention-sink aware variant. The unified-attention kernel does not yet
|
||||
// implement sink tokens, so for now this is an alias for IsOutOfBound.
|
||||
// Host reference code (reference_batched_masking) calls this; keeping the
|
||||
// alias decouples that call site from future sink support.
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
return IsOutOfBound(i_y, i_x);
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y()
|
||||
|
||||
Reference in New Issue
Block a user