Collapse CK-UA traits into single kernel_traits<V, DType, IsMask> template

Replace 4 near-identical *_kernel_traits classes (~400 lines of repeated
shape/policy plumbing) with one templated `unified_attention_kernel_traits`
parameterized by `KernelVariant V`. The 6 dispatch_<variant> helpers in
unified_attention.cpp collapse into a single `dispatch_variant<V>` function
template that fans out over (dtype, mask).

Per-variant compile-time knobs (BlockM, BlockSize, warp count, MFMA shape,
pipeline policy, decode-grid flag) now live in one variant_config<V>
specialization each. "What's different between variants" is readable
top-to-bottom in a single block of code, and each instance .cpp shrinks to
a one-line `INST_UNIFIED_ATTENTION_DISPATCH(V, DTYPE, IS_MASK)` macro.

No behavior change. Correctness suite: 236/240 (the same 4 known
int32-overflow failures as baseline, all at num_blocks=32768 with d=128 MHA).
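
The pattern in miniature (an illustrative sketch with simplified knobs, not
the real CK-tile definitions; see unified_attention_impl.hpp below):

    enum class KernelVariant { prefill, decode };

    template <KernelVariant V> struct variant_config; // one knob block per variant
    template <> struct variant_config<KernelVariant::prefill> {
        static constexpr int BlockM = 256;
    };
    template <> struct variant_config<KernelVariant::decode> {
        static constexpr int BlockM = 128;
    };

    template <KernelVariant V, typename DType, bool IsMask>
    struct kernel_traits {
        static constexpr int kBlockM = variant_config<V>::BlockM;
        // ...shared shape/policy plumbing, written once...
    };
    static_assert(kernel_traits<KernelVariant::decode, float, true>::kBlockM == 128, "");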

Co-authored-by: Cursor <cursoragent@cursor.com>
juuso-oskari
2026-05-12 10:35:15 +00:00
parent 5bd8f73a28
commit fb0d729fbb
26 changed files with 299 additions and 520 deletions

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 128, 128, 1>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 128, 128, 1>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 128, 128, 1>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 128, 128, 1>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 256, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 256, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 256, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 256, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false)
 } // namespace ck_tile

View File

@@ -6,9 +6,6 @@
 namespace ck_tile {
-using kernel_traits =
-    unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8>;
-INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false)
 } // namespace ck_tile

View File

@@ -23,50 +23,40 @@ std::ostream& operator<<(std::ostream& stream,
 //
 // The job is split in two halves so each is small enough to read in one sitting:
 //
-// 1. KernelVariant + select_config(args)
-//    - KernelVariant is a flat enum of every compiled kernel instance the
-//      module knows about. Each entry fixes the static knobs (kBlockM,
-//      warp count, MFMA shape, pipeline policy).
-//    - select_config() is the ONLY place where shape-based runtime
-//      decisions live. It reads (hdim, num_queries_per_kv, avg_q,
-//      max_seqlen_q) and emits a KernelConfig.
+// 1. select_config(args)
+//    - Reads shape (hdim, num_queries_per_kv, avg_q, max_seqlen_q) and
+//      picks one of the KernelVariants defined in unified_attention_impl.hpp.
+//      KernelVariant is the only place where compile-time knobs live —
+//      changing a knob means adding/editing a variant_config<V>.
 //
-// 2. dispatch_<variant>() helpers + the final switch
-//    - Each KernelVariant has a tiny helper that fans out over the
-//      (dtype, mask) cross-product and calls into the existing
-//      DISPATCH_UNIFIED_ATTENTION_* macros. The macros and the
-//      per-variant traits classes are unchanged from before; only the
-//      selection logic moved.
+// 2. dispatch_variant<V>() + the final switch
+//    - dispatch_variant<V>() is a single function template that fans out
+//      over (dtype, mask) and forwards into the per-instance dispatch
+//      function generated by INST_UNIFIED_ATTENTION_DISPATCH.
+//    - The final switch maps KernelVariant -> dispatch_variant<V>.
 //
-// page_size is intentionally NOT part of this enum. The multi-page-tile
-// fix in the pipeline made the compile-time tile-N (kBlockN) independent
-// of the runtime page_blk_size, so every variant is correct for any page
-// size. Selection is driven purely by Q-tile shape.
+// page_size is intentionally NOT part of the config — the multi-page-tile
+// pipeline fix made kBlockN independent of runtime page_blk_size, so every
+// variant is correct for any page size.
 // =============================================================================
-enum class KernelVariant {
-    // d=128 MHA (num_queries_per_kv = 1)
-    prefill_d128_mha,     // kBlockM=256, 8 warps, 32x32 mfma
-    decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
-    // d=64 GQA-8 (num_queries_per_kv = 8)
-    prefill_d64_gqa8,     // kBlockM=256, 8 warps, 32x32 mfma
-    decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
-    decode_d64_gqa8_m64,  // kBlockM=64, 2 warps, 32x32 mfma
-    decode_d64_gqa8_m16,  // kBlockM=16, 1 warp, 16x16 mfma
-};
-struct KernelConfig {
+struct KernelConfig
+{
     KernelVariant variant;
     bool unsupported = false;
 };
 namespace {
-// Internal tier classification — used only by select_config. The tier name is
-// just shorthand for a kBlockM choice; with num_queries_per_kv=8 the tiers
-// correspond to kBlockQ thresholds {2, 8, 16}.
-enum class tile_tier { medium, small, tiny };
+// Internal tile-tier classification — used only by select_config. The tier
+// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the
+// tiers correspond to kBlockQ thresholds {2, 8, 16}.
+enum class tile_tier
+{
+    medium,
+    small,
+    tiny
+};
 tile_tier select_tile_tier(const unified_attention_args& args)
 {
@@ -80,8 +70,8 @@ tile_tier select_tile_tier(const unified_attention_args& args)
     // many more tokens, fall back to the medium tier (1D grid with Q iteration).
     const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
-    if (avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
-    if (avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
+    if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
+    if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
     return tile_tier::medium;
 }
@@ -96,13 +86,13 @@ KernelConfig select_config(const unified_attention_args& args)
     //     both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
     //     short-Q workloads.
     //   * prefill_d128_mha : kBlockM=256, 8 warps. Everything else.
-    if (args.hdim == 128 && args.num_queries_per_kv == 1)
+    if(args.hdim == 128 && args.num_queries_per_kv == 1)
     {
         const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
                                                 : args.num_tokens;
         const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
-        if (avg_q <= 128 && max_q <= 128)
+        if(avg_q <= 128 && max_q <= 128)
             cfg.variant = KernelVariant::decode_d128_mha_m128;
         else
             cfg.variant = KernelVariant::prefill_d128_mha;
@@ -110,9 +100,9 @@ KernelConfig select_config(const unified_attention_args& args)
     }
     // d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
-    if (args.hdim == 64 && args.num_queries_per_kv == 8)
+    if(args.hdim == 64 && args.num_queries_per_kv == 8)
     {
-        switch (select_tile_tier(args))
+        switch(select_tile_tier(args))
         {
         case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
         case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
@@ -126,125 +116,36 @@ KernelConfig select_config(const unified_attention_args& args)
 }
 // -----------------------------------------------------------------------------
-// Dispatch macros and per-variant dispatch helpers.
+// dispatch_variant<V>
 //
-// Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and
-// returns. The per-variant helpers below pick the right macro family and fan
-// out over (dtype, mask). They look repetitive on purpose: a follow-up commit
-// will collapse the 4 traits classes into one templated `kernel_traits<V>`,
-// at which point these helpers become one-liners.
+// One function template. Fans out over (dtype, mask) and forwards into the
+// per-instance dispatch generated by INST_UNIFIED_ATTENTION_DISPATCH. No
+// per-variant boilerplate.
 // -----------------------------------------------------------------------------
-// Helper macro: dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
-#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV)                             \
-    {                                                                                           \
-        using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
-        return unified_attention_kernel_dispatch<kernel_traits>(args, config);                  \
-    }
-// Dispatch macros for three tile tiers (default block_size).
-#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV)                      \
-    {                                                                                                  \
-        using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
-        return unified_attention_kernel_dispatch<kernel_traits>(args, config);                         \
-    }
-#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV)                             \
-    {                                                                                                        \
-        using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
-        return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config);                        \
-    }
-#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV)                             \
-    {                                                                                                       \
-        using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
-        return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config);                       \
-    }
 namespace {
-using DType = unified_attention_args::data_type_enum;
-std::pair<bool, float> dispatch_prefill_d128_mha(
-    const unified_attention_args& args, const stream_config& config)
+template <KernelVariant V>
+std::pair<bool, float> dispatch_variant(const unified_attention_args& args,
+                                        const stream_config& config)
 {
+    using DT = unified_attention_args::data_type_enum;
     const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 128, 256, 1)
-        else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 128, 256, 1)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 128, 256, 1)
-        else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 128, 256, 1)
-    }
-    return {false, -1.f};
-}
-std::pair<bool, float> dispatch_decode_d128_mha_m128(
-    const unified_attention_args& args, const stream_config& config)
-{
-    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1)
-    }
-    return {false, -1.f};
-}
-std::pair<bool, float> dispatch_prefill_d64_gqa8(
-    const unified_attention_args& args, const stream_config& config)
-{
-    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 64, 256, 8)
-        else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 64, 256, 8)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 64, 256, 8)
-        else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 64, 256, 8)
-    }
-    return {false, -1.f};
-}
-std::pair<bool, float> dispatch_decode_d64_gqa8_m128(
-    const unified_attention_args& args, const stream_config& config)
-{
-    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 64, 128, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 64, 128, 8)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 64, 128, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 64, 128, 8)
-    }
-    return {false, -1.f};
-}
-std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
-    const unified_attention_args& args, const stream_config& config)
-{
-    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, false, 64, 64, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, true, 64, 64, 8)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, false, 64, 64, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, true, 64, 64, 8)
-    }
-    return {false, -1.f};
-}
-std::pair<bool, float> dispatch_decode_d64_gqa8_m16(
-    const unified_attention_args& args, const stream_config& config)
-{
-    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
-    if (args.data_type == DType::fp16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, false, 64, 16, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, true, 64, 16, 8)
-    } else if (args.data_type == DType::bf16) {
-        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, false, 64, 16, 8)
-        else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, true, 64, 16, 8)
-    }
+    if(args.data_type == DT::fp16)
+    {
+        if(is_mask)
+            return unified_attention_kernel_dispatch<
+                unified_attention_kernel_traits<V, DT::fp16, true>>(args, config);
+        return unified_attention_kernel_dispatch<
+            unified_attention_kernel_traits<V, DT::fp16, false>>(args, config);
+    }
+    if(args.data_type == DT::bf16)
+    {
+        if(is_mask)
+            return unified_attention_kernel_dispatch<
+                unified_attention_kernel_traits<V, DT::bf16, true>>(args, config);
+        return unified_attention_kernel_dispatch<
+            unified_attention_kernel_traits<V, DT::bf16, false>>(args, config);
+    }
     return {false, -1.f};
 }
@@ -256,30 +157,31 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
 {
     const auto cfg = select_config(args);
-    if (cfg.unsupported)
+    if(cfg.unsupported)
     {
         std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
                   << " num_queries_per_kv=" << args.num_queries_per_kv
-                  << " data_type=" << args.data_type
-                  << " mask_type=" << args.mask_type << std::endl;
+                  << " data_type=" << args.data_type << " mask_type=" << args.mask_type
+                  << std::endl;
         return std::make_pair(false, -1.f);
     }
-    switch (cfg.variant)
+    switch(cfg.variant)
     {
-    case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
-    case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
-    case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
-    case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
-    case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
-    case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
+    case KernelVariant::prefill_d128_mha:
+        return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
+    case KernelVariant::decode_d128_mha_m128:
+        return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
+    case KernelVariant::prefill_d64_gqa8:
+        return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
+    case KernelVariant::decode_d64_gqa8_m128:
+        return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
+    case KernelVariant::decode_d64_gqa8_m64:
+        return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
+    case KernelVariant::decode_d64_gqa8_m16:
+        return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
     }
     return std::make_pair(false, -1.f);
 }
-#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY
-#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL
-#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM
-#undef DISPATCH_UNIFIED_ATTENTION
 } // namespace ck_tile
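
The runtime path is now: select_config(args) -> switch(cfg.variant) ->
dispatch_variant<V> -> the explicit specialization of
unified_attention_kernel_dispatch provided by the matching instance .cpp.
Per the tier comment above, a d=64 GQA-8 workload with avg_q = max_q = 8
lands in tile_tier::small and therefore decode_d64_gqa8_m64. A self-contained
sketch of the same fan-out shape (mock stand-ins for the CK-tile types;
kernel_dispatch here only reports success):

    #include <cstdio>
    #include <utility>

    enum class KernelVariant { prefill_d128_mha, decode_d128_mha_m128 };
    enum class DT { fp16, bf16 };

    // Stand-in for unified_attention_kernel_traits<V, DType, IsMask>.
    template <KernelVariant V, DT D, bool IsMask>
    struct kernel_traits {};

    // Stand-in for unified_attention_kernel_dispatch<Traits>; the real one is
    // declared once and explicitly specialized in each instance .cpp.
    template <typename Traits>
    std::pair<bool, float> kernel_dispatch() { return {true, 0.0f}; }

    // One function template replaces the six per-variant helpers.
    template <KernelVariant V>
    std::pair<bool, float> dispatch_variant(DT dtype, bool is_mask)
    {
        if(dtype == DT::fp16)
            return is_mask ? kernel_dispatch<kernel_traits<V, DT::fp16, true>>()
                           : kernel_dispatch<kernel_traits<V, DT::fp16, false>>();
        if(dtype == DT::bf16)
            return is_mask ? kernel_dispatch<kernel_traits<V, DT::bf16, true>>()
                           : kernel_dispatch<kernel_traits<V, DT::bf16, false>>();
        return {false, -1.f};
    }

    int main()
    {
        const auto r = dispatch_variant<KernelVariant::decode_d128_mha_m128>(DT::bf16, true);
        std::printf("launched=%d elapsed=%.1f\n", r.first ? 1 : 0, r.second);
    }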

View File

@@ -20,26 +20,36 @@
#include "unified_attention.hpp"
#include "mask.hpp"
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair( \
true, unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
}
#define INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch_decode<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair( \
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
}
namespace ck_tile {
// =============================================================================
// KernelVariant
//
// Flat enum of every compiled kernel instance. Each variant fixes
// (kBlockM, warp count, MFMA shape, pipeline policy) via a variant_config<V>
// specialization below. This is the single source of truth for "what knobs
// differ between kernel instances".
//
// page_size is intentionally NOT part of this enum. The multi-page-tile fix
// in the pipeline decoupled kBlockN from page_blk_size, so every variant is
// correct for any page size.
// =============================================================================
enum class KernelVariant
{
// d=128 MHA (num_queries_per_kv = 1)
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
// d=64 GQA-8 (num_queries_per_kv = 8)
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
};
// -----------------------------------------------------------------------------
// Per-DataType problem element types.
// -----------------------------------------------------------------------------
template <unified_attention_args::data_type_enum DataType>
struct unified_attention_problem_traits;
@@ -61,261 +71,182 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
     using lse_dtype = float;
 };
-// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV, BlockSize
-template <unified_attention_args::data_type_enum DataType,
-          bool IsMasking,
-          index_t HeadSize_ = 128,
-          index_t BlockM_ = 256,
-          index_t NumQPerKV_ = 1,
-          index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
+// =============================================================================
+// variant_config<V>
+//
+// One specialization per KernelVariant. Each exposes the static knobs that
+// distinguish that variant from the others:
+//
+//   HeadSize       : head dimension (compile-time)
+//   BlockM         : Q-tile size along the M (token) axis
+//   NumQPerKV      : 1 for MHA, 8 for GQA-8
+//   BlockSize      : kBlockN — KV-tile size along the N axis
+//   BlockWarps     : warp layout, sequence<M, N, K>
+//   WarpGemmShape  : MFMA tile shape, sequence<M, N, K>
+//   Pipeline<P>    : pipeline template (default vs decode vs tiny-decode policy)
+//   kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false)
+// =============================================================================
+template <KernelVariant V>
+struct variant_config;
+template <>
+struct variant_config<KernelVariant::prefill_d128_mha>
+{
+    static constexpr index_t HeadSize = 128;
+    static constexpr index_t BlockM = 256;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps = sequence<8, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = false;
+};
+template <>
+struct variant_config<KernelVariant::decode_d128_mha_m128>
+{
+    static constexpr index_t HeadSize = 128;
+    static constexpr index_t BlockM = 128;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps = sequence<4, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = false;
+};
+template <>
+struct variant_config<KernelVariant::prefill_d64_gqa8>
+{
+    static constexpr index_t HeadSize = 64;
+    static constexpr index_t BlockM = 256;
+    static constexpr index_t NumQPerKV = 8;
+    static constexpr index_t BlockSize = 64;
+    using BlockWarps = sequence<8, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = false;
+};
+template <>
+struct variant_config<KernelVariant::decode_d64_gqa8_m128>
+{
+    static constexpr index_t HeadSize = 64;
+    static constexpr index_t BlockM = 128;
+    static constexpr index_t NumQPerKV = 8;
+    static constexpr index_t BlockSize = 64;
+    using BlockWarps = sequence<4, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem>;
+    static constexpr bool kUseDecodeGrid = false;
+};
+template <>
+struct variant_config<KernelVariant::decode_d64_gqa8_m64>
+{
+    static constexpr index_t HeadSize = 64;
+    static constexpr index_t BlockM = 64;
+    static constexpr index_t NumQPerKV = 8;
+    static constexpr index_t BlockSize = 64;
+    using BlockWarps = sequence<2, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDecodePolicy>;
+    static constexpr bool kUseDecodeGrid = true;
+};
+template <>
+struct variant_config<KernelVariant::decode_d64_gqa8_m16>
+{
+    static constexpr index_t HeadSize = 64;
+    static constexpr index_t BlockM = 16;
+    static constexpr index_t NumQPerKV = 8;
+    static constexpr index_t BlockSize = 64;
+    using BlockWarps = sequence<1, 1, 1>;
+    using WarpGemmShape = sequence<16, 16, 32>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy>;
+    static constexpr bool kUseDecodeGrid = true;
+};
+// =============================================================================
+// unified_attention_kernel_traits<V, DataType, IsMasking>
+//
+// Single templated trait. Pulls per-variant knobs from variant_config<V> and
+// per-dtype element types from unified_attention_problem_traits<DataType>.
+// =============================================================================
+template <KernelVariant V,
+          unified_attention_args::data_type_enum DataType,
+          bool IsMasking>
 struct unified_attention_kernel_traits
 {
-    static constexpr auto date_type = DataType;
-    static constexpr bool is_masking = IsMasking;
+    using cfg = variant_config<V>;
+    using dt  = unified_attention_problem_traits<DataType>;
-    static constexpr index_t kBlockM = BlockM_;
-    static constexpr index_t HEAD_SIZE = HeadSize_;
-    static constexpr index_t BLOCK_SIZE = BlockSize_;
+    static constexpr auto date_type = DataType;
+    static constexpr bool is_masking = IsMasking;
+    static constexpr KernelVariant variant = V;
-    static constexpr index_t num_queries_per_kv = NumQPerKV_;
+    static constexpr index_t HEAD_SIZE = cfg::HeadSize;
+    static constexpr index_t kBlockM = cfg::BlockM;
+    static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
+    static constexpr index_t num_queries_per_kv = cfg::NumQPerKV;
     static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
+    static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
     //                                            kBlockM  kBlockQ  BLOCK_SIZE  HEAD_SIZE
     using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
-    using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
-    // 8 warps for warp specialization; kBlockM must be 8 * 32 = 256
-    using unified_attention_block_warps = sequence<8, 1, 1>;
+    using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
+    using unified_attention_block_warps = typename cfg::BlockWarps;
     using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
                                                               unified_attention_block_warps,
                                                               unified_attention_warp_gemm_shape,
                                                               unified_attention_block_warps,
                                                               unified_attention_warp_gemm_shape,
-                                                              true // IsVLayoutRowMajor
-                                                              >;
+                                                              true>; // IsVLayoutRowMajor
     using unified_attention_traits = TileUnifiedAttentionTraits<true,  // kPadSeqLenQ_
                                                                 false, // kPadHeadDimQ
-                                                                -1     // kBlockPerCu
-                                                                >;
+                                                                -1>;   // kBlockPerCu
-    using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
+    using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
-    using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::lse_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::o_dtype,
-        unified_attention_shape,
-        unified_attention_mask,
-        unified_attention_traits>;
-    using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
-    using epilogue = Default2DEpilogue<
-        Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
-                                 typename unified_attention_problem_traits<date_type>::o_dtype,
-                                 true, // kPadM
-                                 true, // kPadM
-                                 true  // UseRawStore
-                                 >>;
-    using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
-};
-// Decode-tuned traits: 4 warps (1 warp group), kBlockM=128, serial pipeline.
-// Uses the single-warp-group path in UnifiedAttentionPipeline.
-template <unified_attention_args::data_type_enum DataType,
-          bool IsMasking,
-          index_t HeadSize_ = 128,
-          index_t BlockM_ = 128,
-          index_t NumQPerKV_ = 1,
-          index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
-struct unified_attention_decode_kernel_traits
-{
-    static constexpr auto date_type = DataType;
-    static constexpr bool is_masking = IsMasking;
-    static constexpr index_t kBlockM = BlockM_;
-    static constexpr index_t HEAD_SIZE = HeadSize_;
-    static constexpr index_t BLOCK_SIZE = BlockSize_;
-    static constexpr index_t num_queries_per_kv = NumQPerKV_;
-    static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
-    //                                            kBlockM  kBlockQ  BLOCK_SIZE  HEAD_SIZE
-    using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
-    using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
-    // 4 warps -> kBlockSize = 256 threads -> NumWarpGroups = 1
-    using unified_attention_block_warps = sequence<4, 1, 1>;
-    using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              true>;
-    using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
-    using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
-    using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::lse_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::o_dtype,
-        unified_attention_shape,
-        unified_attention_mask,
-        unified_attention_traits>;
-    using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
-    using epilogue = Default2DEpilogue<
-        Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
-                                 typename unified_attention_problem_traits<date_type>::o_dtype,
-                                 true, true, true>>;
-    using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
-};
-// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2).
-// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed.
-template <unified_attention_args::data_type_enum DataType,
-          bool IsMasking,
-          index_t HeadSize_ = 64,
-          index_t BlockM_ = 64,
-          index_t NumQPerKV_ = 8,
-          index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
-struct unified_attention_decode_small_kernel_traits
-{
-    static constexpr auto date_type = DataType;
-    static constexpr bool is_masking = IsMasking;
-    static constexpr index_t kBlockM = BlockM_;
-    static constexpr index_t HEAD_SIZE = HeadSize_;
-    static constexpr index_t BLOCK_SIZE = BlockSize_;
-    static constexpr index_t num_queries_per_kv = NumQPerKV_;
-    static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
-    using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
-    using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
-    // 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1
-    using unified_attention_block_warps = sequence<2, 1, 1>;
-    using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              true>;
-    using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
-    using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
-    using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::lse_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::o_dtype,
-        unified_attention_shape,
-        unified_attention_mask,
-        unified_attention_traits>;
+    using unified_attention_pipeline_problem =
+        UnifiedAttentionPipelineProblem<typename dt::qkvp_dtype,
+                                        typename dt::qkvp_dtype,
+                                        typename dt::qkvp_dtype,
+                                        typename dt::acc_dtype,
+                                        typename dt::acc_dtype,
+                                        typename dt::acc_dtype,
+                                        typename dt::lse_dtype,
+                                        typename dt::qkvp_dtype,
+                                        typename dt::acc_dtype,
+                                        typename dt::o_dtype,
+                                        unified_attention_shape,
+                                        unified_attention_mask,
+                                        unified_attention_traits>;
     using unified_attention_pipeline =
-        UnifiedAttentionPipeline<unified_attention_pipeline_problem,
-                                 UnifiedAttentionPipelineDecodePolicy>;
+        typename cfg::template Pipeline<unified_attention_pipeline_problem>;
-    using epilogue = Default2DEpilogue<
-        Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
-                                 typename unified_attention_problem_traits<date_type>::o_dtype,
-                                 true, true, true>>;
-    using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
-};
-// Tiny decode traits: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 for GQA-8.
-// Matches Triton's BLOCK_M=16 / BLOCK_Q=2 decode configuration.
-// Uses block_tile_reduce_sync instead of permlane32_swap for 16x16 MFMA.
-template <unified_attention_args::data_type_enum DataType,
-          bool IsMasking,
-          index_t HeadSize_ = 64,
-          index_t BlockM_ = 16,
-          index_t NumQPerKV_ = 8,
-          index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
-struct unified_attention_decode_tiny_kernel_traits
-{
-    static constexpr auto date_type = DataType;
-    static constexpr bool is_masking = IsMasking;
-    static constexpr index_t kBlockM = BlockM_;
-    static constexpr index_t HEAD_SIZE = HeadSize_;
-    static constexpr index_t BLOCK_SIZE = BlockSize_;
-    static constexpr index_t num_queries_per_kv = NumQPerKV_;
-    static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
-    using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
-    using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
-    // 1 warp: kBlockM=1*16=16, kBlockSize=64, NumWarpGroups=1
-    using unified_attention_block_warps = sequence<1, 1, 1>;
-    using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              unified_attention_block_warps,
-                                                              unified_attention_warp_gemm_shape,
-                                                              true>;
-    using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
-    using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
-    using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::lse_dtype,
-        typename unified_attention_problem_traits<date_type>::qkvp_dtype,
-        typename unified_attention_problem_traits<date_type>::acc_dtype,
-        typename unified_attention_problem_traits<date_type>::o_dtype,
-        unified_attention_shape,
-        unified_attention_mask,
-        unified_attention_traits>;
-    using unified_attention_pipeline =
-        UnifiedAttentionPipeline<unified_attention_pipeline_problem,
-                                 UnifiedAttentionPipelineTinyDecodePolicy>;
-    using epilogue = Default2DEpilogue<
-        Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
-                                 typename unified_attention_problem_traits<date_type>::o_dtype,
-                                 true, true, true>>;
+    using epilogue =
+        Default2DEpilogue<Default2DEpilogueProblem<typename dt::acc_dtype,
+                                                   typename dt::o_dtype,
+                                                   true, // kPadM
+                                                   true, // kPadN
+                                                   true  // UseRawStore
+                                                   >>;
     using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
 };
+// =============================================================================
+// Kernel launch — common helper. Picks the grid layout from
+// Traits::kUseDecodeGrid; all other launch args are identical across variants.
+// =============================================================================
 template <typename Kernel, bool UseDecodeGrid = false>
 float unified_attention_kernel_launch(const unified_attention_args& args,
                                       const stream_config& config)
@@ -380,15 +311,33 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
     return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
 }
-// return value:
-//   first  = whether the kernel was launched (true = launched, false = skipped)
-//   second = elapsed time (ms) of the kernel launch, valid only if first == true
-template <typename KernelTraits>
+// =============================================================================
+// Per-instance dispatch. Each instance .cpp specializes this for its
+// (V, DataType, IsMasking) tuple via INST_UNIFIED_ATTENTION_DISPATCH.
+//
+// Return: (launched?, elapsed_ms). elapsed_ms is valid only when launched.
+// =============================================================================
+template <typename Traits>
 std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention_args& args,
                                                          const stream_config& config);
-template <typename KernelTraits>
-std::pair<bool, float> unified_attention_kernel_dispatch_decode(const unified_attention_args& args,
-                                                                const stream_config& config);
 } // namespace ck_tile
+// One-line instantiation per (V, DataType, IsMasking) combination. Each
+// instance .cpp consists of exactly one of these calls.
+#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_)                           \
+    template <>                                                                               \
+    std::pair<bool, float> unified_attention_kernel_dispatch<                                 \
+        unified_attention_kernel_traits<KernelVariant::VARIANT_,                              \
+                                        unified_attention_args::data_type_enum::DTYPE_,       \
+                                        IS_MASK_>>(const unified_attention_args& args,        \
+                                                   const stream_config& config)               \
+    {                                                                                         \
+        using Traits = unified_attention_kernel_traits<                                       \
+            KernelVariant::VARIANT_,                                                          \
+            unified_attention_args::data_type_enum::DTYPE_,                                   \
+            IS_MASK_>;                                                                        \
+        return std::make_pair(true,                                                           \
+                              unified_attention_kernel_launch<typename Traits::kernel,        \
+                                                              Traits::kUseDecodeGrid>(args, config)); \
+    }
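
Modulo whitespace, the single line in an instance .cpp, e.g.
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true), expands to:

    template <>
    std::pair<bool, float> unified_attention_kernel_dispatch<
        unified_attention_kernel_traits<KernelVariant::prefill_d128_mha,
                                        unified_attention_args::data_type_enum::bf16,
                                        true>>(const unified_attention_args& args,
                                               const stream_config& config)
    {
        using Traits = unified_attention_kernel_traits<
            KernelVariant::prefill_d128_mha,
            unified_attention_args::data_type_enum::bf16,
            true>;
        return std::make_pair(true,
                              unified_attention_kernel_launch<typename Traits::kernel,
                                                              Traits::kUseDecodeGrid>(args, config));
    }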