mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM)
After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel became the only things still tying a compiled binary to a
specific num_queries_per_kv (num_qpkv). Drop both, so that the existing
instances now serve every num_qpkv that divides kBlockM evenly.
Concretely:
* `variant_config<V>` -- remove the NumQPerKV field from every
specialization.
* `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
/ `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
the actual kBlockQ is always the runtime value.
* `unified_attention_kernel_launch` -- recompute kBlockQ at host time
from `args.num_queries_per_kv` for the total_num_q_blocks math.
* `unified_attention_kernel.hpp` -- drop the
`assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
just removed).
* `unified_attention.cpp::select_config` -- collapse the two
per-num_qpkv code paths into a single (head_dim, avg_rows,
max_rows) ladder, where avg_rows = avg_q * num_qpkv.
Variant renames (8 variants):
prefill_d128_mha -> prefill_d128
decode_d128_mha_m128 -> decode_d128_m128
decode_d128_mha_m32 -> decode_d128_m32
decode_d128_mha_m16 -> decode_d128_m16
prefill_d64_gqa8 -> prefill_d64
decode_d64_gqa8_m128 -> decode_d64_m128
decode_d64_gqa8_m64 -> decode_d64_m64
decode_d64_gqa8_m16 -> decode_d64_m16
The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 per head_dim = 4 variants x
2 dtypes x 2 mask settings).
Validation:
* Correctness suite: 241/245 pass (the 4 failures are the same
pre-existing int32-overflow failures in the prefill rebased-pointer path).
* d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
correctly on the existing decode_d128_m* binaries with num_qpkv=8
at runtime. max abs diff <= 1e-2 vs the torch reference at ql in
{1, 4, 16}.
* d=64 MHA (also a new combo) -- runs correctly on the existing
decode_d64_m* binaries with num_qpkv=1. Same tolerance.
* Perf sweep (b=4..256, sk=120000, MI300):
d=64 GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
of baseline).
d=128 MHA: speedups 0.98x..1.14x vs Triton (within 0.3%
of baseline).
Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false)
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -46,72 +46,52 @@ struct KernelConfig
|
||||
bool unsupported = false;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Internal tile-tier classification — used only by select_config. The tier
|
||||
// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the
|
||||
// tiers correspond to kBlockQ thresholds {2, 8, 16}.
|
||||
enum class tile_tier
|
||||
{
|
||||
medium,
|
||||
small,
|
||||
tiny
|
||||
};
|
||||
|
||||
tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
||||
: args.num_tokens;
|
||||
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
|
||||
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
|
||||
|
||||
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
|
||||
// seq has at most kBlockQ tokens. For mixed batches where some seqs have
|
||||
// many more tokens, fall back to the medium tier (1D grid with Q iteration).
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
|
||||
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
|
||||
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
|
||||
return tile_tier::medium;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
KernelConfig select_config(const unified_attention_args& args)
|
||||
{
|
||||
KernelConfig cfg;
|
||||
|
||||
// d=128 MHA — tile-tier ladder by (avg_q, max_q):
|
||||
// * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 mfma (tiny-decode)
|
||||
// * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 mfma (tiny-decode)
|
||||
// * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default)
|
||||
// * prefill_d128_mha : kBlockM=256, 8 warps, 32x32 mfma
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
||||
// The variants are now num_queries_per_kv-agnostic (kBlockQ is runtime
|
||||
// inside the kernel) -- we just have to pick a kBlockM that holds enough
|
||||
// rows for `num_qpkv * max_q` and that num_qpkv divides cleanly.
|
||||
//
|
||||
// `avg_q * num_qpkv` is the *effective* per-CTA tile occupancy; e.g.
|
||||
// GQA-8 with sq=1 produces 8 rows per Q tile, the same as MHA with sq=8.
|
||||
// Tiering on that quantity lets one variant ladder serve both regimes.
|
||||
const index_t num_qpkv = args.num_queries_per_kv;
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
||||
: args.num_tokens;
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
const index_t avg_rows = avg_q * num_qpkv; // effective rows per Q tile
|
||||
const index_t max_rows = max_q * num_qpkv;
|
||||
|
||||
if(avg_q <= 16 && max_q <= 16)
|
||||
cfg.variant = KernelVariant::decode_d128_mha_m16;
|
||||
else if(avg_q <= 32 && max_q <= 32)
|
||||
cfg.variant = KernelVariant::decode_d128_mha_m32;
|
||||
else if(avg_q <= 128 && max_q <= 128)
|
||||
cfg.variant = KernelVariant::decode_d128_mha_m128;
|
||||
if(args.hdim == 128)
|
||||
{
|
||||
// d=128 ladder: m16 / m32 / m128 / prefill (m256). Requires
|
||||
// num_qpkv to divide the chosen kBlockM, which is automatic for
|
||||
// num_qpkv in {1, 2, 4, 8, 16} and any of the kBlockM's below.
|
||||
if(avg_rows <= 16 && max_rows <= 16)
|
||||
cfg.variant = KernelVariant::decode_d128_m16;
|
||||
else if(avg_rows <= 32 && max_rows <= 32)
|
||||
cfg.variant = KernelVariant::decode_d128_m32;
|
||||
else if(avg_rows <= 128 && max_rows <= 128)
|
||||
cfg.variant = KernelVariant::decode_d128_m128;
|
||||
else
|
||||
cfg.variant = KernelVariant::prefill_d128_mha;
|
||||
cfg.variant = KernelVariant::prefill_d128;
|
||||
return cfg;
|
||||
}
|
||||
|
||||
// d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
if(args.hdim == 64)
|
||||
{
|
||||
switch(select_tile_tier(args))
|
||||
{
|
||||
case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
|
||||
case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
|
||||
case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break;
|
||||
}
|
||||
// d=64 ladder: m16 / m64 / m128 / prefill (m256). Same shape
|
||||
// selection logic as d=128; the variant's kBlockN is just bigger.
|
||||
if(avg_rows <= 16 && max_rows <= 16)
|
||||
cfg.variant = KernelVariant::decode_d64_m16;
|
||||
else if(avg_rows <= 64 && max_rows <= 64)
|
||||
cfg.variant = KernelVariant::decode_d64_m64;
|
||||
else if(avg_rows <= 128 && max_rows <= 128)
|
||||
cfg.variant = KernelVariant::decode_d64_m128;
|
||||
else
|
||||
cfg.variant = KernelVariant::prefill_d64;
|
||||
return cfg;
|
||||
}
|
||||
|
||||
@@ -172,22 +152,22 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
|
||||
switch(cfg.variant)
|
||||
{
|
||||
case KernelVariant::prefill_d128_mha:
|
||||
return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
|
||||
case KernelVariant::decode_d128_mha_m32:
|
||||
return dispatch_variant<KernelVariant::decode_d128_mha_m32>(args, config);
|
||||
case KernelVariant::decode_d128_mha_m16:
|
||||
return dispatch_variant<KernelVariant::decode_d128_mha_m16>(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8:
|
||||
return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m64:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m16:
|
||||
return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
|
||||
case KernelVariant::prefill_d128:
|
||||
return dispatch_variant<KernelVariant::prefill_d128>(args, config);
|
||||
case KernelVariant::decode_d128_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d128_m128>(args, config);
|
||||
case KernelVariant::decode_d128_m32:
|
||||
return dispatch_variant<KernelVariant::decode_d128_m32>(args, config);
|
||||
case KernelVariant::decode_d128_m16:
|
||||
return dispatch_variant<KernelVariant::decode_d128_m16>(args, config);
|
||||
case KernelVariant::prefill_d64:
|
||||
return dispatch_variant<KernelVariant::prefill_d64>(args, config);
|
||||
case KernelVariant::decode_d64_m128:
|
||||
return dispatch_variant<KernelVariant::decode_d64_m128>(args, config);
|
||||
case KernelVariant::decode_d64_m64:
|
||||
return dispatch_variant<KernelVariant::decode_d64_m64>(args, config);
|
||||
case KernelVariant::decode_d64_m16:
|
||||
return dispatch_variant<KernelVariant::decode_d64_m16>(args, config);
|
||||
}
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
@@ -36,17 +36,20 @@ namespace ck_tile {
|
||||
// =============================================================================
|
||||
enum class KernelVariant
|
||||
{
|
||||
// d=128 MHA (num_queries_per_kv = 1)
|
||||
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
|
||||
decode_d128_mha_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy)
|
||||
decode_d128_mha_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
|
||||
// d=128 (num_queries_per_kv chosen at *runtime* — same binary serves both
|
||||
// MHA and GQA-N as long as num_qpkv divides kBlockM). kBlockM is the only
|
||||
// structural compile-time knob; pick the tier by max_q after multiplying
|
||||
// by num_qpkv in select_config.
|
||||
prefill_d128, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d128_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d128_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy)
|
||||
decode_d128_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
|
||||
|
||||
// d=64 GQA-8 (num_queries_per_kv = 8)
|
||||
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
|
||||
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
|
||||
// d=64.
|
||||
prefill_d64, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d64_m128, // kBlockM=128, 4 warps, 32x32 mfma
|
||||
decode_d64_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
|
||||
decode_d64_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -81,22 +84,25 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
|
||||
//
|
||||
// HeadSize : head dimension (compile-time)
|
||||
// BlockM : Q-tile size along the M (token) axis
|
||||
// NumQPerKV : 1 for MHA, 8 for GQA-8
|
||||
// BlockSize : kBlockN — KV-tile size along the N axis
|
||||
// BlockWarps : warp layout, sequence<M, N, K>
|
||||
// WarpGemmShape : MFMA tile shape, sequence<M, N, K>
|
||||
// Pipeline<P> : pipeline template (default vs decode vs tiny-decode policy)
|
||||
// kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false)
|
||||
//
|
||||
// num_queries_per_kv is *not* a compile-time knob: kBlockQ = kBlockM /
|
||||
// num_qpkv is computed at runtime inside the kernel and pipeline. The only
|
||||
// constraint is `kBlockM % num_qpkv == 0` (host-side select_config makes sure
|
||||
// of this).
|
||||
// =============================================================================
|
||||
template <KernelVariant V>
|
||||
struct variant_config;
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::prefill_d128_mha>
|
||||
struct variant_config<KernelVariant::prefill_d128>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 256;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -106,11 +112,10 @@ struct variant_config<KernelVariant::prefill_d128_mha>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d128_mha_m128>
|
||||
struct variant_config<KernelVariant::decode_d128_m128>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 128;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -120,11 +125,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m128>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d128_mha_m32>
|
||||
struct variant_config<KernelVariant::decode_d128_m32>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 32;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -134,11 +138,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d128_mha_m16>
|
||||
struct variant_config<KernelVariant::decode_d128_m16>
|
||||
{
|
||||
static constexpr index_t HeadSize = 128;
|
||||
static constexpr index_t BlockM = 16;
|
||||
static constexpr index_t NumQPerKV = 1;
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<16, 16, 32>;
|
||||
@@ -148,11 +151,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m16>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::prefill_d64_gqa8>
|
||||
struct variant_config<KernelVariant::prefill_d64>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 256;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -162,11 +164,10 @@ struct variant_config<KernelVariant::prefill_d64_gqa8>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m128>
|
||||
struct variant_config<KernelVariant::decode_d64_m128>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 128;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -176,11 +177,10 @@ struct variant_config<KernelVariant::decode_d64_gqa8_m128>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m64>
|
||||
struct variant_config<KernelVariant::decode_d64_m64>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 64;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<2, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
@@ -190,11 +190,10 @@ struct variant_config<KernelVariant::decode_d64_gqa8_m64>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct variant_config<KernelVariant::decode_d64_gqa8_m16>
|
||||
struct variant_config<KernelVariant::decode_d64_m16>
|
||||
{
|
||||
static constexpr index_t HeadSize = 64;
|
||||
static constexpr index_t BlockM = 16;
|
||||
static constexpr index_t NumQPerKV = 8;
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<16, 16, 32>;
|
||||
@@ -221,14 +220,19 @@ struct unified_attention_kernel_traits
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr KernelVariant variant = V;
|
||||
|
||||
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
|
||||
static constexpr index_t kBlockM = cfg::BlockM;
|
||||
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
|
||||
static constexpr index_t num_queries_per_kv = cfg::NumQPerKV;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
|
||||
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
|
||||
static constexpr index_t kBlockM = cfg::BlockM;
|
||||
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
|
||||
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
// The 2nd entry of the BlockTile is the static `kBlockQ` exposed via
|
||||
// `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads
|
||||
// kBlockQ from `args.num_queries_per_kv` at runtime, this static value
|
||||
// is only the fallback when no num_qpkv was plumbed through (which never
|
||||
// happens in practice). Anchor it at kBlockM so the static "looks like
|
||||
// num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv
|
||||
// == 0 works without touching this trait.
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
|
||||
using unified_attention_block_warps = typename cfg::BlockWarps;
|
||||
|
||||
@@ -281,8 +285,12 @@ template <typename Kernel, bool UseDecodeGrid = false>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
constexpr index_t kBlockQ = Kernel::kBlockQ;
|
||||
index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
|
||||
// kBlockQ is derived from the runtime num_queries_per_kv now -- the
|
||||
// static `Kernel::kBlockQ` is anchored at kBlockM and would under-count
|
||||
// Q blocks for GQA workloads. select_config is expected to pick a kBlockM
|
||||
// that num_qpkv divides, so this integer divide is exact.
|
||||
const index_t kBlockQ = Kernel::kBlockM / args.num_queries_per_kv;
|
||||
const index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
|
||||
@@ -263,13 +263,14 @@ struct UnifiedAttentionKernel
|
||||
|
||||
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
|
||||
|
||||
// kBlockQ derived at runtime from num_queries_per_kv. For the variants
|
||||
// we ship today this matches the compile-time `kBlockQ` from the
|
||||
// pipeline trait (the assert below catches any disagreement); the
|
||||
// explicit runtime form is what eventually lets a single kernel
|
||||
// instantiation cover multiple num_queries_per_kv values.
|
||||
// kBlockQ derived at runtime from num_queries_per_kv. The static
|
||||
// `kBlockQ` from the pipeline trait is anchored at kBlockM (i.e. it
|
||||
// describes num_qpkv == 1) so the same compiled binary serves every
|
||||
// num_qpkv that divides kBlockM evenly -- e.g. the d=128 variants
|
||||
// can run both MHA and GQA-N at runtime with no recompile. The host
|
||||
// side (select_config) is responsible for enforcing kBlockM %
|
||||
// num_queries_per_kv == 0.
|
||||
const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv;
|
||||
assert(kBlockQ_dyn == kBlockQ);
|
||||
|
||||
// Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
|
||||
// split index lives in z — when num_splits == 1 (the only z value)
|
||||
|
||||
Reference in New Issue
Block a user