mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Collapse CK-UA traits into single kernel_traits<V, DType, IsMask> template
Replace 4 near-identical *_kernel_traits classes (~400 lines of repeated shape/policy plumbing) with one templated `unified_attention_kernel_traits` parameterized by `KernelVariant V`. The 6 dispatch_<variant> helpers in unified_attention.cpp collapse into a single `dispatch_variant<V>` function template that fans out over (dtype, mask). Per-variant compile-time knobs (BlockM, BlockSize, warp count, MFMA shape, pipeline policy, decode-grid flag) now live in one variant_config<V> specialization each. "What's different between variants" is readable top-to-bottom in a single block of code, and each instance .cpp shrinks to a one-line `INST_UNIFIED_ATTENTION_DISPATCH(V, DTYPE, IS_MASK)` macro. No behavior change. Correctness suite: 236/240 (same 4 known num_blocks=32768 + d=128 MHA int32-overflow failures as baseline). Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 128, 128, 1>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 128, 128, 1>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 128, 128, 1>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 128, 128, 1>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,9 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -23,50 +23,40 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
//
|
||||
// The job is split in two halves so each is small enough to read in one sitting:
|
||||
//
|
||||
// 1. KernelVariant + select_config(args)
|
||||
// - KernelVariant is a flat enum of every compiled kernel instance the
|
||||
// module knows about. Each entry fixes the static knobs (kBlockM,
|
||||
// warp count, MFMA shape, pipeline policy).
|
||||
// - select_config() is the ONLY place where shape-based runtime
|
||||
// decisions live. It reads (hdim, num_queries_per_kv, avg_q,
|
||||
// max_seqlen_q) and emits a KernelConfig.
|
||||
// 1. select_config(args)
|
||||
// - Reads shape (hdim, num_queries_per_kv, avg_q, max_seqlen_q) and
|
||||
// picks one of the KernelVariants defined in unified_attention_impl.hpp.
|
||||
// KernelVariant is the only place where compile-time knobs live —
|
||||
// changing a knob means adding/editing a variant_config<V>.
|
||||
//
|
||||
// 2. dispatch_<variant>() helpers + the final switch
|
||||
// - Each KernelVariant has a tiny helper that fans out over the
|
||||
// (dtype, mask) cross-product and calls into the existing
|
||||
// DISPATCH_UNIFIED_ATTENTION_* macros. The macros and the
|
||||
// per-variant traits classes are unchanged from before; only the
|
||||
// selection logic moved.
|
||||
// 2. dispatch_variant<V>() + the final switch
|
||||
// - dispatch_variant<V>() is a single function template that fans out
|
||||
// over (dtype, mask) and forwards into the per-instance dispatch
|
||||
// function generated by INST_UNIFIED_ATTENTION_DISPATCH.
|
||||
// - The final switch maps KernelVariant -> dispatch_variant<V>.
|
||||
//
|
||||
// page_size is intentionally NOT part of this enum. The multi-page-tile
|
||||
// fix in the pipeline made the compile-time tile-N (kBlockN) independent
|
||||
// of the runtime page_blk_size, so every variant is correct for any page
|
||||
// size. Selection is driven purely by Q-tile shape.
|
||||
// page_size is intentionally NOT part of the config — the multi-page-tile
|
||||
// pipeline fix made kBlockN independent of runtime page_blk_size, so every
|
||||
// variant is correct for any page size.
|
||||
// =============================================================================
|
||||
|
||||
enum class KernelVariant {
|
||||
// d=128 MHA (num_queries_per_kv = 1)
|
||||
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
|
||||
|
||||
// d=64 GQA-8 (num_queries_per_kv = 8)
|
||||
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma
|
||||
};
|
||||
|
||||
struct KernelConfig {
|
||||
struct KernelConfig
|
||||
{
|
||||
KernelVariant variant;
|
||||
bool unsupported = false;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Internal tier classification — used only by select_config. The tier name is
|
||||
// just shorthand for a kBlockM choice; with num_queries_per_kv=8 the tiers
|
||||
// correspond to kBlockQ thresholds {2, 8, 16}.
|
||||
enum class tile_tier { medium, small, tiny };
|
||||
// Internal tile-tier classification — used only by select_config. The tier
|
||||
// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the
|
||||
// tiers correspond to kBlockQ thresholds {2, 8, 16}.
|
||||
enum class tile_tier
|
||||
{
|
||||
medium,
|
||||
small,
|
||||
tiny
|
||||
};
|
||||
|
||||
tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
{
|
||||
@@ -80,8 +70,8 @@ tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
// many more tokens, fall back to the medium tier (1D grid with Q iteration).
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
|
||||
if (avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
|
||||
if (avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
|
||||
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
|
||||
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
|
||||
return tile_tier::medium;
|
||||
}
|
||||
|
||||
@@ -96,13 +86,13 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
// both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
|
||||
// short-Q workloads.
|
||||
// * prefill_d128_mha : kBlockM=256, 8 warps. Everything else.
|
||||
if (args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
||||
: args.num_tokens;
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
|
||||
if (avg_q <= 128 && max_q <= 128)
|
||||
if(avg_q <= 128 && max_q <= 128)
|
||||
cfg.variant = KernelVariant::decode_d128_mha_m128;
|
||||
else
|
||||
cfg.variant = KernelVariant::prefill_d128_mha;
|
||||
@@ -110,9 +100,9 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
}
|
||||
|
||||
// d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
|
||||
if (args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
switch (select_tile_tier(args))
|
||||
switch(select_tile_tier(args))
|
||||
{
|
||||
case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
|
||||
case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
|
||||
@@ -126,125 +116,36 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Dispatch macros and per-variant dispatch helpers.
|
||||
// dispatch_variant<V>
|
||||
//
|
||||
// Each DISPATCH_* macro instantiates one (traits, dtype, mask, ...) combo and
|
||||
// returns. The per-variant helpers below pick the right macro family and fan
|
||||
// out over (dtype, mask). They look repetitive on purpose: a follow-up commit
|
||||
// will collapse the 4 traits classes into one templated `kernel_traits<V>`,
|
||||
// at which point these helpers become one-liners.
|
||||
// One function template. Fans out over (dtype, mask) and forwards into the
|
||||
// per-instance dispatch generated by INST_UNIFIED_ATTENTION_DISPATCH. No
|
||||
// per-variant boilerplate.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Helper macro: dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
|
||||
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Dispatch macros for three tile tiers (default block_size).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
using DType = unified_attention_args::data_type_enum;
|
||||
|
||||
std::pair<bool, float> dispatch_prefill_d128_mha(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
template <KernelVariant V>
|
||||
std::pair<bool, float> dispatch_variant(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
using DT = unified_attention_args::data_type_enum;
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 128, 256, 1)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 128, 256, 1)
|
||||
|
||||
if(args.data_type == DT::fp16)
|
||||
{
|
||||
if(is_mask)
|
||||
return unified_attention_kernel_dispatch<
|
||||
unified_attention_kernel_traits<V, DT::fp16, true>>(args, config);
|
||||
return unified_attention_kernel_dispatch<
|
||||
unified_attention_kernel_traits<V, DT::fp16, false>>(args, config);
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d128_mha_m128(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_prefill_d64_gqa8(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::fp16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(DType::fp16, true, 64, 256, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION(DType::bf16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(DType::bf16, true, 64, 256, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m128(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 64, 128, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 64, 128, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m64(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::fp16, true, 64, 64, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType::bf16, true, 64, 64, 8)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d64_gqa8_m16(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, false, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::fp16, true, 64, 16, 8)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, false, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType::bf16, true, 64, 16, 8)
|
||||
if(args.data_type == DT::bf16)
|
||||
{
|
||||
if(is_mask)
|
||||
return unified_attention_kernel_dispatch<
|
||||
unified_attention_kernel_traits<V, DT::bf16, true>>(args, config);
|
||||
return unified_attention_kernel_dispatch<
|
||||
unified_attention_kernel_traits<V, DT::bf16, false>>(args, config);
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
@@ -256,30 +157,31 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
{
|
||||
const auto cfg = select_config(args);
|
||||
|
||||
if (cfg.unsupported)
|
||||
if(cfg.unsupported)
|
||||
{
|
||||
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
|
||||
<< " num_queries_per_kv=" << args.num_queries_per_kv
|
||||
<< " data_type=" << args.data_type
|
||||
<< " mask_type=" << args.mask_type << std::endl;
|
||||
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type
|
||||
<< std::endl;
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
switch (cfg.variant)
|
||||
switch(cfg.variant)
|
||||
{
|
||||
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64: return dispatch_decode_d64_gqa8_m64(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m16: return dispatch_decode_d64_gqa8_m16(args, config);
|
||||
case KernelVariant::prefill_d128_mha:
|
||||
return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8:
|
||||
return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m16:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
|
||||
}
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM
|
||||
#undef DISPATCH_UNIFIED_ATTENTION
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -20,26 +20,36 @@
|
||||
#include "unified_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
}
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch_decode<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// =============================================================================
|
||||
// KernelVariant
|
||||
//
|
||||
// Flat enum of every compiled kernel instance. Each variant fixes
|
||||
// (kBlockM, warp count, MFMA shape, pipeline policy) via a variant_config<V>
|
||||
// specialization below. This is the single source of truth for "what knobs
|
||||
// differ between kernel instances".
|
||||
//
|
||||
// page_size is intentionally NOT part of this enum. The multi-page-tile fix
|
||||
// in the pipeline decoupled kBlockN from page_blk_size, so every variant is
|
||||
// correct for any page size.
|
||||
// =============================================================================
|
||||
enum class KernelVariant
|
||||
{
|
||||
// d=128 MHA (num_queries_per_kv = 1)
|
||||
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
|
||||
|
||||
// d=64 GQA-8 (num_queries_per_kv = 8)
|
||||
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
|
||||
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Per-DataType problem element types.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <unified_attention_args::data_type_enum DataType>
|
||||
struct unified_attention_problem_traits;
|
||||
|
||||
@@ -61,261 +71,182 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV, BlockSize
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 256,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
// =============================================================================
|
||||
// variant_config<V>
|
||||
//
|
||||
// One specialization per KernelVariant. Each exposes the static knobs that
|
||||
// distinguish that variant from the others:
|
||||
//
|
||||
// HeadSize : head dimension (compile-time)
|
||||
// BlockM : Q-tile size along the M (token) axis
|
||||
// NumQPerKV : 1 for MHA, 8 for GQA-8
|
||||
// BlockSize : kBlockN — KV-tile size along the N axis
|
||||
// BlockWarps : warp layout, sequence<M, N, K>
|
||||
// WarpGemmShape : MFMA tile shape, sequence<M, N, K>
|
||||
// Pipeline<P> : pipeline template (default vs decode vs tiny-decode policy)
|
||||
// kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false)
|
||||
// =============================================================================
|
||||
template <KernelVariant V>
|
||||
struct variant_config;
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::prefill_d128_mha>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 256;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d128_mha_m128>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 128;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::prefill_d64_gqa8>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 256;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m128>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 128;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m64>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 64;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<2, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDecodePolicy>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m16>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 16;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<16, 16, 32>;
|
||||
template <typename Problem>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// unified_attention_kernel_traits<V, DataType, IsMasking>
|
||||
//
|
||||
// Single templated trait. Pulls per-variant knobs from variant_config<V> and
|
||||
// per-dtype element types from unified_attention_problem_traits<DataType>.
|
||||
// =============================================================================
|
||||
template <KernelVariant V,
|
||||
unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
using cfg = variant_config<V>;
|
||||
using dt = unified_attention_problem_traits<DataType>;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr KernelVariant variant = V;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
|
||||
static constexpr index_t kBlockM = cfg::BlockM;
|
||||
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
|
||||
static constexpr index_t num_queries_per_kv = cfg::NumQPerKV;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
|
||||
|
||||
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 8 warps for warp specialization; kBlockM must be 8 * 32 = 256
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
|
||||
using unified_attention_block_warps = typename cfg::BlockWarps;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true // IsVLayoutRowMajor
|
||||
>;
|
||||
true>; // IsVLayoutRowMajor
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
-1>; // kBlockPerCu
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, // kPadM
|
||||
true, // kPadM
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Decode-tuned traits: 4 warps (1 warp group), kBlockM=128, serial pipeline.
|
||||
// Uses the single-warp-group path in UnifiedAttentionPipeline.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 128,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 4 warps -> kBlockSize = 256 threads -> NumWarpGroups = 1
|
||||
using unified_attention_block_warps = sequence<4, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2).
|
||||
// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 64,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_small_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
using unified_attention_pipeline_problem =
|
||||
UnifiedAttentionPipelineProblem<typename dt::qkvp_dtype,
|
||||
typename dt::qkvp_dtype,
|
||||
typename dt::qkvp_dtype,
|
||||
typename dt::acc_dtype,
|
||||
typename dt::acc_dtype,
|
||||
typename dt::acc_dtype,
|
||||
typename dt::lse_dtype,
|
||||
typename dt::qkvp_dtype,
|
||||
typename dt::acc_dtype,
|
||||
typename dt::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineDecodePolicy>;
|
||||
typename cfg::template Pipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Tiny decode traits: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 for GQA-8.
|
||||
// Matches Triton's BLOCK_M=16 / BLOCK_Q=2 decode configuration.
|
||||
// Uses block_tile_reduce_sync instead of permlane32_swap for 16x16 MFMA.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 16,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_tiny_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
|
||||
// 1 warp: kBlockM=1*16=16, kBlockSize=64, NumWarpGroups=1
|
||||
using unified_attention_block_warps = sequence<1, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineTinyDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
using epilogue =
|
||||
Default2DEpilogue<Default2DEpilogueProblem<typename dt::acc_dtype,
|
||||
typename dt::o_dtype,
|
||||
true, // kPadM
|
||||
true, // kPadN
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Kernel launch — common helper. Picks the grid layout from
|
||||
// Traits::kUseDecodeGrid; all other launch args are identical across variants.
|
||||
// =============================================================================
|
||||
template <typename Kernel, bool UseDecodeGrid = false>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
@@ -380,15 +311,33 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
template <typename KernelTraits>
|
||||
// =============================================================================
|
||||
// Per-instance dispatch. Each instance .cpp specializes this for its
|
||||
// (V, DataType, IsMasking) tuple via INST_UNIFIED_ATTENTION_DISPATCH.
|
||||
//
|
||||
// Return: (launched?, elapsed_ms). elapsed_ms is valid only when launched.
|
||||
// =============================================================================
|
||||
template <typename Traits>
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
template <typename KernelTraits>
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch_decode(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// One-line instantiation per (V, DataType, IsMasking) combination. Each
|
||||
// instance .cpp consists of exactly one of these calls.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch< \
|
||||
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
{ \
|
||||
using Traits = unified_attention_kernel_traits< \
|
||||
KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_>; \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<typename Traits::kernel, \
|
||||
Traits::kUseDecodeGrid>(args, config)); \
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user