mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
For pure-decode workloads (sq=1) at d=128 the m128 tile wastes most of
its 128 query rows, capping CK below Triton on every batch size in our
sweep (4..256). Add two small-Q tiers that mirror the d=64 GQA-8 ladder:
* decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 MFMA (tiny-decode)
* decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 MFMA (tiny-decode)
select_config now ladders by (avg_q, max_q): m16 -> m32 -> m128 -> prefill.
d=128 MHA, hq=16/hk=16, sq=1, sk=120k, num_blocks=60k:
batch   before   after   CK BW
4       ~0.95x   0.98x   4.76 TB/s
8       ~0.85x   1.29x   5.00 TB/s
32      ~0.85x   1.14x   5.29 TB/s
64      ~0.75x   0.93x   5.35 TB/s
128     ~1.00x   1.09x   5.39 TB/s
256     ~1.03x   1.02x   5.41 TB/s
Correctness suite stays at 241/245 (same 4 known int32-overflow
failures in the prefill path).
Co-authored-by: Cursor <cursoragent@cursor.com>
196 lines
7.6 KiB
C++
196 lines
7.6 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include "unified_attention.hpp"
|
|
#include "unified_attention_impl.hpp"
|
|
#include "mask.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
std::ostream& operator<<(std::ostream& stream,
|
|
const unified_attention_args::data_type_enum& data_type)
|
|
{
|
|
switch(data_type)
|
|
{
|
|
case unified_attention_args::data_type_enum::fp16: return stream << "fp16";
|
|
case unified_attention_args::data_type_enum::bf16: return stream << "bf16";
|
|
default: return stream << "unknown";
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Config selection
|
|
//
|
|
// The job is split in two halves so each is small enough to read in one sitting:
|
|
//
|
|
// 1. select_config(args)
|
|
// - Reads shape (hdim, num_queries_per_kv, avg_q, max_seqlen_q) and
|
|
// picks one of the KernelVariants defined in unified_attention_impl.hpp.
|
|
// KernelVariant is the only place where compile-time knobs live —
|
|
// changing a knob means adding/editing a variant_config<V>.
|
|
//
|
|
// 2. dispatch_variant<V>() + the final switch
|
|
// - dispatch_variant<V>() is a single function template that fans out
|
|
// over (dtype, mask) and forwards into the per-instance dispatch
|
|
// function generated by INST_UNIFIED_ATTENTION_DISPATCH.
|
|
// - The final switch maps KernelVariant -> dispatch_variant<V>.
|
|
//
|
|
// page_size is intentionally NOT part of the config — the multi-page-tile
|
|
// pipeline fix made kBlockN independent of runtime page_blk_size, so every
|
|
// variant is correct for any page size.
|
|
// =============================================================================
|
|
|
|
struct KernelConfig
|
|
{
|
|
KernelVariant variant;
|
|
bool unsupported = false;
|
|
};
|
|
|
|
namespace {
|
|
|
|
// Internal tile-tier classification — used only by select_config. The tier
|
|
// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the
|
|
// tiers correspond to kBlockQ thresholds {2, 8, 16}.
|
|
enum class tile_tier
|
|
{
|
|
medium,
|
|
small,
|
|
tiny
|
|
};
|
|
|
|
tile_tier select_tile_tier(const unified_attention_args& args)
|
|
{
|
|
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
|
: args.num_tokens;
|
|
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
|
|
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
|
|
|
|
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
|
|
// seq has at most kBlockQ tokens. For mixed batches where some seqs have
|
|
// many more tokens, fall back to the medium tier (1D grid with Q iteration).
|
|
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
|
|
|
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
|
|
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
|
|
return tile_tier::medium;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
/// Picks the kernel variant for the given problem shape.
///
/// Two shape families are recognised:
///   * hdim == 128 with num_queries_per_kv == 1 (MHA)
///   * hdim == 64  with num_queries_per_kv == 8 (GQA-8)
/// Any other combination yields a config with `unsupported` set to true.
///
/// @param args  Runtime attention arguments; only hdim, num_queries_per_kv,
///              num_tokens, num_seqs and max_seqlen_q are read here.
/// @return The selected variant, or an unsupported marker.
KernelConfig select_config(const unified_attention_args& args)
{
    KernelConfig cfg;

    const bool is_d128_mha  = args.hdim == 128 && args.num_queries_per_kv == 1;
    const bool is_d64_gqa8  = args.hdim == 64 && args.num_queries_per_kv == 8;

    if(is_d128_mha)
    {
        // d=128 MHA ladder, keyed on both the average and the maximum
        // per-sequence query length:
        //   decode_d128_mha_m16  : kBlockM=16,  1 warp,  16x16 mfma
        //   decode_d128_mha_m32  : kBlockM=32,  1 warp,  32x32 mfma
        //   decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma
        //   prefill_d128_mha     : kBlockM=256, 8 warps, 32x32 mfma
        index_t mean_q = args.num_tokens;
        if(args.num_seqs > 0)
        {
            mean_q = args.num_tokens / args.num_seqs;
        }

        // A non-positive max_seqlen_q falls back to the average.
        index_t longest_q = mean_q;
        if(args.max_seqlen_q > 0)
        {
            longest_q = args.max_seqlen_q;
        }

        const auto fits_in = [mean_q, longest_q](index_t block_m) {
            return mean_q <= block_m && longest_q <= block_m;
        };

        // Smallest tile that covers every sequence wins; anything larger
        // than 128 goes to the prefill variant.
        cfg.variant = fits_in(16)    ? KernelVariant::decode_d128_mha_m16
                      : fits_in(32)  ? KernelVariant::decode_d128_mha_m32
                      : fits_in(128) ? KernelVariant::decode_d128_mha_m128
                                     : KernelVariant::prefill_d128_mha;
        return cfg;
    }

    if(is_d64_gqa8)
    {
        // d=64 GQA-8: the variant follows the tile tier directly; page size
        // plays no part in the choice.
        const tile_tier tier = select_tile_tier(args);
        if(tier == tile_tier::tiny)
        {
            cfg.variant = KernelVariant::decode_d64_gqa8_m16;
        }
        else if(tier == tile_tier::small)
        {
            cfg.variant = KernelVariant::decode_d64_gqa8_m64;
        }
        else if(tier == tile_tier::medium)
        {
            cfg.variant = KernelVariant::decode_d64_gqa8_m128;
        }
        return cfg;
    }

    // No shape family matched.
    cfg.unsupported = true;
    return cfg;
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// dispatch_variant<V>
|
|
//
|
|
// One function template. Fans out over (dtype, mask) and forwards into the
|
|
// per-instance dispatch generated by INST_UNIFIED_ATTENTION_DISPATCH. No
|
|
// per-variant boilerplate.
|
|
// -----------------------------------------------------------------------------
|
|
namespace {
|
|
|
|
template <KernelVariant V>
|
|
std::pair<bool, float> dispatch_variant(const unified_attention_args& args,
|
|
const stream_config& config)
|
|
{
|
|
using DT = unified_attention_args::data_type_enum;
|
|
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
|
|
|
if(args.data_type == DT::fp16)
|
|
{
|
|
if(is_mask)
|
|
return unified_attention_kernel_dispatch<
|
|
unified_attention_kernel_traits<V, DT::fp16, true>>(args, config);
|
|
return unified_attention_kernel_dispatch<
|
|
unified_attention_kernel_traits<V, DT::fp16, false>>(args, config);
|
|
}
|
|
if(args.data_type == DT::bf16)
|
|
{
|
|
if(is_mask)
|
|
return unified_attention_kernel_dispatch<
|
|
unified_attention_kernel_traits<V, DT::bf16, true>>(args, config);
|
|
return unified_attention_kernel_dispatch<
|
|
unified_attention_kernel_traits<V, DT::bf16, false>>(args, config);
|
|
}
|
|
return {false, -1.f};
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
|
const stream_config& config)
|
|
{
|
|
const auto cfg = select_config(args);
|
|
|
|
if(cfg.unsupported)
|
|
{
|
|
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
|
|
<< " num_queries_per_kv=" << args.num_queries_per_kv
|
|
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type
|
|
<< std::endl;
|
|
return std::make_pair(false, -1.f);
|
|
}
|
|
|
|
switch(cfg.variant)
|
|
{
|
|
case KernelVariant::prefill_d128_mha:
|
|
return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
|
|
case KernelVariant::decode_d128_mha_m128:
|
|
return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
|
|
case KernelVariant::decode_d128_mha_m32:
|
|
return dispatch_variant<KernelVariant::decode_d128_mha_m32>(args, config);
|
|
case KernelVariant::decode_d128_mha_m16:
|
|
return dispatch_variant<KernelVariant::decode_d128_mha_m16>(args, config);
|
|
case KernelVariant::prefill_d64_gqa8:
|
|
return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
|
|
case KernelVariant::decode_d64_gqa8_m128:
|
|
return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
|
|
case KernelVariant::decode_d64_gqa8_m64:
|
|
return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
|
|
case KernelVariant::decode_d64_gqa8_m16:
|
|
return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
|
|
}
|
|
return std::make_pair(false, -1.f);
|
|
}
|
|
|
|
} // namespace ck_tile
|