Files
composable_kernel/example/ck_tile/42_unified_attention/unified_attention.cpp
juuso-oskari f5beedb2e9 Add CK-UA decode_d128_mha_m32 / _m16 small-Q tiers
For pure-decode workloads (sq=1) at d=128 the m128 tile wastes most of
its 128 query rows, capping CK below Triton on every batch size in our
sweep (4..256). Add two small-Q tiers that mirror the d=64 GQA-8 ladder:

  * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 MFMA  (tiny-decode)
  * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 MFMA  (tiny-decode)

select_config now ladders by (avg_q, max_q): m16 -> m32 -> m128 -> prefill.

d=128 MHA, hq=16/hk=16, sq=1, sk=120k, num_blocks=60k:
  batch  before    after    CK BW
      4  ~0.95x   0.98x   4.76 TB/s
      8  ~0.85x   1.29x   5.00 TB/s
     32  ~0.85x   1.14x   5.29 TB/s
     64  ~0.75x   0.93x   5.35 TB/s
    128  ~1.00x   1.09x   5.39 TB/s
    256  ~1.03x   1.02x   5.41 TB/s

Correctness suite stays at 241/245 (same 4 known int32-overflow
failures in the prefill path).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 11:48:19 +00:00

196 lines
7.6 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
// Writes the short textual name of an attention data type to `stream`.
//
// fp16 and bf16 print as "fp16" and "bf16"; every other enumerator value
// prints as "unknown". Returns `stream` so insertions can be chained.
std::ostream& operator<<(std::ostream& stream,
                         const unified_attention_args::data_type_enum& data_type)
{
    using dtype = unified_attention_args::data_type_enum;

    // Start from the fallback label and overwrite it for the known types.
    const char* label = "unknown";
    if(data_type == dtype::fp16)
    {
        label = "fp16";
    }
    else if(data_type == dtype::bf16)
    {
        label = "bf16";
    }
    return stream << label;
}
// =============================================================================
// Config selection
//
// The job is split in two halves so each is small enough to read in one sitting:
//
// 1. select_config(args)
// - Reads shape (hdim, num_queries_per_kv, avg_q, max_seqlen_q) and
// picks one of the KernelVariants defined in unified_attention_impl.hpp.
// KernelVariant is the only place where compile-time knobs live —
// changing a knob means adding/editing a variant_config<V>.
//
// 2. dispatch_variant<V>() + the final switch
// - dispatch_variant<V>() is a single function template that fans out
// over (dtype, mask) and forwards into the per-instance dispatch
// function generated by INST_UNIFIED_ATTENTION_DISPATCH.
// - The final switch maps KernelVariant -> dispatch_variant<V>.
//
// page_size is intentionally NOT part of the config — the multi-page-tile
// pipeline fix made kBlockN independent of runtime page_blk_size, so every
// variant is correct for any page size.
// =============================================================================
// Outcome of select_config(): either the kernel variant to launch or a flag
// saying that no variant matches the requested shape.
struct KernelConfig
{
    // Variant chosen for the launch. Only meaningful when `unsupported` is
    // false. Value-initialized so that the field never holds an indeterminate
    // value when select_config() takes the unsupported path and returns
    // without assigning it.
    KernelVariant variant{};
    // True when no kernel variant matches the requested shape.
    bool unsupported = false;
};
namespace {
// Internal tile-tier classification: produced by select_tile_tier and consumed
// by select_config. The tier name is shorthand for a kBlockM choice. The tiny
// and small cut-offs are 16 / num_queries_per_kv and 64 / num_queries_per_kv
// query tokens per sequence, i.e. 2 and 8 when num_queries_per_kv=8; medium is
// the fallback for everything larger.
enum class tile_tier
{
// Fallback tier: chosen whenever the batch does not fit tiny or small.
// select_config maps it to decode_d64_gqa8_m128.
medium,
// Average and maximum query length are both <= 64 / num_queries_per_kv.
// select_config maps it to decode_d64_gqa8_m64.
small,
// Average and maximum query length are both <= 16 / num_queries_per_kv.
// select_config maps it to decode_d64_gqa8_m16.
tiny
};
// Classifies a batch into a tile tier from its per-sequence query lengths.
//
// A tier is chosen only when BOTH the average and the maximum query length fit
// its limit. The decode tiers use a 2D grid (num_kv_heads, num_seqs) that
// assumes each sequence has at most kBlockQ tokens, so a mixed batch with a
// few long sequences must fall through to the medium tier (1D grid with Q
// iteration).
tile_tier select_tile_tier(const unified_attention_args& args)
{
    // Mean query tokens per sequence (integer division). With no sequences the
    // raw token count is used instead.
    index_t mean_q = args.num_tokens;
    if(args.num_seqs > 0)
    {
        mean_q = args.num_tokens / args.num_seqs;
    }

    // Longest query in the batch; falls back to the mean when the caller did
    // not supply a positive max_seqlen_q.
    index_t longest_q = mean_q;
    if(args.max_seqlen_q > 0)
    {
        longest_q = args.max_seqlen_q;
    }

    // Per-tier query-token limits, scaled down by the GQA group size.
    const index_t tiny_limit  = 16 / args.num_queries_per_kv;
    const index_t small_limit = 64 / args.num_queries_per_kv;

    const auto fits = [mean_q, longest_q](index_t limit) {
        return mean_q <= limit && longest_q <= limit;
    };

    if(fits(tiny_limit))
    {
        return tile_tier::tiny;
    }
    if(fits(small_limit))
    {
        return tile_tier::small;
    }
    return tile_tier::medium;
}
} // anonymous namespace
// Picks the kernel variant for the shape described by `args`.
//
// Two shape families are recognised:
//   * hdim == 128 with num_queries_per_kv == 1 (MHA)
//   * hdim == 64  with num_queries_per_kv == 8 (GQA-8)
// Any other combination yields a config with `unsupported` set to true.
// page_size plays no part in the decision.
KernelConfig select_config(const unified_attention_args& args)
{
    KernelConfig result;

    const bool is_d128_mha = (args.hdim == 128) && (args.num_queries_per_kv == 1);
    const bool is_d64_gqa8 = (args.hdim == 64) && (args.num_queries_per_kv == 8);

    if(is_d128_mha)
    {
        // Ladder by (average, maximum) query length; a rung is taken only when
        // both values fit it:
        //   <= 16  -> decode_d128_mha_m16  (kBlockM=16,  1 warp,  16x16 mfma)
        //   <= 32  -> decode_d128_mha_m32  (kBlockM=32,  1 warp,  32x32 mfma)
        //   <= 128 -> decode_d128_mha_m128 (kBlockM=128, 4 warps, 32x32 mfma)
        //   else   -> prefill_d128_mha     (kBlockM=256, 8 warps, 32x32 mfma)
        index_t mean_q = args.num_tokens;
        if(args.num_seqs > 0)
        {
            mean_q = args.num_tokens / args.num_seqs;
        }
        // Use the mean when no positive max_seqlen_q was supplied.
        const index_t longest_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : mean_q;

        const auto fits = [mean_q, longest_q](index_t rows) {
            return mean_q <= rows && longest_q <= rows;
        };

        if(fits(16))
        {
            result.variant = KernelVariant::decode_d128_mha_m16;
        }
        else if(fits(32))
        {
            result.variant = KernelVariant::decode_d128_mha_m32;
        }
        else if(fits(128))
        {
            result.variant = KernelVariant::decode_d128_mha_m128;
        }
        else
        {
            result.variant = KernelVariant::prefill_d128_mha;
        }
        return result;
    }

    if(is_d64_gqa8)
    {
        // Pure tile-tier ladder; the tier decides the variant one-to-one.
        const tile_tier tier = select_tile_tier(args);
        if(tier == tile_tier::tiny)
        {
            result.variant = KernelVariant::decode_d64_gqa8_m16;
        }
        else if(tier == tile_tier::small)
        {
            result.variant = KernelVariant::decode_d64_gqa8_m64;
        }
        else if(tier == tile_tier::medium)
        {
            result.variant = KernelVariant::decode_d64_gqa8_m128;
        }
        return result;
    }

    // No shape family matched.
    result.unsupported = true;
    return result;
}
// -----------------------------------------------------------------------------
// dispatch_variant<V>
//
// One function template. Fans out over (dtype, mask) and forwards into the
// per-instance dispatch generated by INST_UNIFIED_ATTENTION_DISPATCH. No
// per-variant boilerplate.
// -----------------------------------------------------------------------------
namespace {
template <KernelVariant V>
std::pair<bool, float> dispatch_variant(const unified_attention_args& args,
const stream_config& config)
{
using DT = unified_attention_args::data_type_enum;
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
if(args.data_type == DT::fp16)
{
if(is_mask)
return unified_attention_kernel_dispatch<
unified_attention_kernel_traits<V, DT::fp16, true>>(args, config);
return unified_attention_kernel_dispatch<
unified_attention_kernel_traits<V, DT::fp16, false>>(args, config);
}
if(args.data_type == DT::bf16)
{
if(is_mask)
return unified_attention_kernel_dispatch<
unified_attention_kernel_traits<V, DT::bf16, true>>(args, config);
return unified_attention_kernel_dispatch<
unified_attention_kernel_traits<V, DT::bf16, false>>(args, config);
}
return {false, -1.f};
}
} // anonymous namespace
std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config)
{
const auto cfg = select_config(args);
if(cfg.unsupported)
{
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
<< " num_queries_per_kv=" << args.num_queries_per_kv
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type
<< std::endl;
return std::make_pair(false, -1.f);
}
switch(cfg.variant)
{
case KernelVariant::prefill_d128_mha:
return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
case KernelVariant::decode_d128_mha_m128:
return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
case KernelVariant::decode_d128_mha_m32:
return dispatch_variant<KernelVariant::decode_d128_mha_m32>(args, config);
case KernelVariant::decode_d128_mha_m16:
return dispatch_variant<KernelVariant::decode_d128_mha_m16>(args, config);
case KernelVariant::prefill_d64_gqa8:
return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
case KernelVariant::decode_d64_gqa8_m128:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
case KernelVariant::decode_d64_gqa8_m64:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
case KernelVariant::decode_d64_gqa8_m16:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile