CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM)

After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel were the only things still tying a compiled binary to a
specific num_queries_per_kv. Dropping both lets the existing instances
serve every num_qpkv that divides kBlockM evenly.

Concretely:
  * `variant_config<V>` -- remove the NumQPerKV field from every
    specialization.
  * `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
    / `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
    entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
    anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
    the actual kBlockQ is always the runtime value.
  * `unified_attention_kernel_launch` -- recompute kBlockQ at host time
    from `args.num_queries_per_kv` for the total_num_q_blocks math (see
    the sketch after this list).
  * `unified_attention_kernel.hpp` -- drop the
    `assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
    just removed).
  * `unified_attention.cpp::select_config` -- collapse the two
    per-num_qpkv code paths into a single (head_dim, avg_rows,
    max_rows) ladder, where avg_rows = avg_q * num_qpkv.
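
  The host and device sides now derive kBlockQ the same way. A minimal
  standalone sketch of that arithmetic (plain C++; `kBlockM`,
  `num_queries_per_kv`, `num_tokens`, `num_seqs` are placeholder
  parameters, not the real kernel traits):

    #include <cassert>
    #include <cstdio>

    // Host-side sketch: number of Q blocks to launch once kBlockQ is a
    // runtime quantity rather than Kernel::kBlockQ.
    int total_q_blocks(int kBlockM, int num_queries_per_kv,
                       int num_tokens, int num_seqs)
    {
        assert(kBlockM % num_queries_per_kv == 0); // enforced by select_config
        const int kBlockQ = kBlockM / num_queries_per_kv;
        // One extra block per sequence covers each seq's partial tail tile.
        return num_tokens / kBlockQ + num_seqs;
    }

    int main()
    {
        // decode_d128_m128 (kBlockM=128) serving GQA-8: 128/8 = 16 query
        // tokens per tile -> 1024/16 + 64 = 128 blocks.
        std::printf("%d\n", total_q_blocks(128, 8, 1024, 64));
        // The same binary serving MHA (num_qpkv=1): kBlockQ = 128
        // -> 1024/128 + 64 = 72 blocks.
        std::printf("%d\n", total_q_blocks(128, 1, 1024, 64));
        return 0;
    }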

Variant renames (8 variants):
  prefill_d128_mha       -> prefill_d128
  decode_d128_mha_m128   -> decode_d128_m128
  decode_d128_mha_m32    -> decode_d128_m32
  decode_d128_mha_m16    -> decode_d128_m16
  prefill_d64_gqa8       -> prefill_d64
  decode_d64_gqa8_m128   -> decode_d64_m128
  decode_d64_gqa8_m64    -> decode_d64_m64
  decode_d64_gqa8_m16    -> decode_d64_m16

The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 files per head_dim, i.e.
4 variants x 2 dtypes x 2 mask settings).

Validation:
  * Correctness suite: 241/245 (same 4 pre-existing int32-overflow
    failures in the prefill rebased-pointer path).
  * d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
    correctly on the existing decode_d128_m* binaries with num_qpkv=8
    at runtime. Max abs diff <= 1e-2 vs the torch reference at query
    lengths in {1, 4, 16}.
  * d=64 MHA (also a new combo) -- runs correctly on the existing
    decode_d64_m* binaries with num_qpkv=1. Same tolerance.
  * Perf sweep (b=4..256, sk=120000, MI300):
      d=64  GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
                   of baseline).
      d=128 MHA:   speedups 0.98x..1.14x vs Triton (within 0.3%
                   of baseline).

Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.
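
For illustration, a standalone sketch of the new tiering (simplified
stand-ins for the real KernelVariant / select_config, not the actual
code), showing how a hypothetical new combo such as d=64 with
num_qpkv=4 lands on existing binaries:

    #include <cstdio>

    // Simplified stand-in for the d=64 branch of select_config.
    enum class Variant { decode_d64_m16, decode_d64_m64, decode_d64_m128, prefill_d64 };

    Variant pick_d64(int num_queries_per_kv, int avg_q, int max_q)
    {
        // Effective rows per Q tile: GQA-N with sq=1 fills N rows, the same
        // as MHA with sq=N, so one ladder serves both regimes.
        const int avg_rows = avg_q * num_queries_per_kv;
        const int max_rows = max_q * num_queries_per_kv;
        if(avg_rows <= 16 && max_rows <= 16)   return Variant::decode_d64_m16;
        if(avg_rows <= 64 && max_rows <= 64)   return Variant::decode_d64_m64;
        if(avg_rows <= 128 && max_rows <= 128) return Variant::decode_d64_m128;
        return Variant::prefill_d64;
    }

    int main()
    {
        // d=64, num_qpkv=4, single-token decode: 4 effective rows -> m16.
        std::printf("%d\n", static_cast<int>(pick_d64(4, 1, 1)));
        // Same combo, 40-token chunked prefill: 160 rows -> prefill_d64.
        std::printf("%d\n", static_cast<int>(pick_d64(4, 40, 40)));
        return 0;
    }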

Co-authored-by: Cursor <cursoragent@cursor.com>
juuso-oskari committed 2026-05-12 12:15:55 +00:00
parent 614afea7eb
commit d77f0bea63
47 changed files with 254 additions and 265 deletions

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, true)
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, bf16, false)
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, bf16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, bf16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, bf16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, bf16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, bf16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, true)
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128_mha, fp16, false)
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d128, fp16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m128, fp16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m128, fp16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m32, fp16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_m16, fp16, false)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, true)
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, true)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, bf16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, true)
} // namespace ck_tile

View File

@@ -6,6 +6,6 @@
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64_gqa8, fp16, false)
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, bf16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, bf16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, bf16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, true)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(prefill_d64, fp16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m128, fp16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m64, fp16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_m16, fp16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m128, fp16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m64, fp16, false)
} // namespace ck_tile

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
INST_UNIFIED_ATTENTION_DISPATCH(decode_d64_gqa8_m16, fp16, false)
} // namespace ck_tile

View File

@@ -46,72 +46,52 @@ struct KernelConfig
bool unsupported = false;
};
namespace {
// Internal tile-tier classification — used only by select_config. The tier
// name is shorthand for a kBlockM choice; with num_queries_per_kv=8 the
// tiers correspond to kBlockQ thresholds {2, 8, 16}.
enum class tile_tier
{
medium,
small,
tiny
};
tile_tier select_tile_tier(const unified_attention_args& args)
{
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
: args.num_tokens;
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
// seq has at most kBlockQ tokens. For mixed batches where some seqs have
// many more tokens, fall back to the medium tier (1D grid with Q iteration).
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) return tile_tier::tiny;
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) return tile_tier::small;
return tile_tier::medium;
}
} // anonymous namespace
KernelConfig select_config(const unified_attention_args& args)
{
KernelConfig cfg;
// d=128 MHA — tile-tier ladder by (avg_q, max_q):
// * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 mfma (tiny-decode)
// * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 mfma (tiny-decode)
// * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default)
// * prefill_d128_mha : kBlockM=256, 8 warps, 32x32 mfma
if(args.hdim == 128 && args.num_queries_per_kv == 1)
{
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
// The variants are now num_queries_per_kv-agnostic (kBlockQ is runtime
// inside the kernel) -- we just have to pick a kBlockM that holds enough
// rows for `num_qpkv * max_q` and that num_qpkv divides cleanly.
//
// `avg_q * num_qpkv` is the *effective* per-CTA tile occupancy; e.g.
// GQA-8 with sq=1 produces 8 rows per Q tile, the same as MHA with sq=8.
// Tiering on that quantity lets one variant ladder serve both regimes.
const index_t num_qpkv = args.num_queries_per_kv;
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
: args.num_tokens;
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
const index_t avg_rows = avg_q * num_qpkv; // effective rows per Q tile
const index_t max_rows = max_q * num_qpkv;
if(avg_q <= 16 && max_q <= 16)
cfg.variant = KernelVariant::decode_d128_mha_m16;
else if(avg_q <= 32 && max_q <= 32)
cfg.variant = KernelVariant::decode_d128_mha_m32;
else if(avg_q <= 128 && max_q <= 128)
cfg.variant = KernelVariant::decode_d128_mha_m128;
if(args.hdim == 128)
{
// d=128 ladder: m16 / m32 / m128 / prefill (m256). Requires
// num_qpkv to divide the chosen kBlockM, which is automatic for
// num_qpkv in {1, 2, 4, 8, 16} and any of the kBlockM's below.
if(avg_rows <= 16 && max_rows <= 16)
cfg.variant = KernelVariant::decode_d128_m16;
else if(avg_rows <= 32 && max_rows <= 32)
cfg.variant = KernelVariant::decode_d128_m32;
else if(avg_rows <= 128 && max_rows <= 128)
cfg.variant = KernelVariant::decode_d128_m128;
else
cfg.variant = KernelVariant::prefill_d128_mha;
cfg.variant = KernelVariant::prefill_d128;
return cfg;
}
// d=64 GQA-8 — pure tile-tier ladder. page_size has no influence here.
if(args.hdim == 64 && args.num_queries_per_kv == 8)
if(args.hdim == 64)
{
switch(select_tile_tier(args))
{
case tile_tier::tiny: cfg.variant = KernelVariant::decode_d64_gqa8_m16; break;
case tile_tier::small: cfg.variant = KernelVariant::decode_d64_gqa8_m64; break;
case tile_tier::medium: cfg.variant = KernelVariant::decode_d64_gqa8_m128; break;
}
// d=64 ladder: m16 / m64 / m128 / prefill (m256). Same shape
// selection logic as d=128; the variant's kBlockN is just bigger.
if(avg_rows <= 16 && max_rows <= 16)
cfg.variant = KernelVariant::decode_d64_m16;
else if(avg_rows <= 64 && max_rows <= 64)
cfg.variant = KernelVariant::decode_d64_m64;
else if(avg_rows <= 128 && max_rows <= 128)
cfg.variant = KernelVariant::decode_d64_m128;
else
cfg.variant = KernelVariant::prefill_d64;
return cfg;
}
@@ -172,22 +152,22 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
switch(cfg.variant)
{
case KernelVariant::prefill_d128_mha:
return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
case KernelVariant::decode_d128_mha_m128:
return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
case KernelVariant::decode_d128_mha_m32:
return dispatch_variant<KernelVariant::decode_d128_mha_m32>(args, config);
case KernelVariant::decode_d128_mha_m16:
return dispatch_variant<KernelVariant::decode_d128_mha_m16>(args, config);
case KernelVariant::prefill_d64_gqa8:
return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
case KernelVariant::decode_d64_gqa8_m128:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m128>(args, config);
case KernelVariant::decode_d64_gqa8_m64:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m64>(args, config);
case KernelVariant::decode_d64_gqa8_m16:
return dispatch_variant<KernelVariant::decode_d64_gqa8_m16>(args, config);
case KernelVariant::prefill_d128:
return dispatch_variant<KernelVariant::prefill_d128>(args, config);
case KernelVariant::decode_d128_m128:
return dispatch_variant<KernelVariant::decode_d128_m128>(args, config);
case KernelVariant::decode_d128_m32:
return dispatch_variant<KernelVariant::decode_d128_m32>(args, config);
case KernelVariant::decode_d128_m16:
return dispatch_variant<KernelVariant::decode_d128_m16>(args, config);
case KernelVariant::prefill_d64:
return dispatch_variant<KernelVariant::prefill_d64>(args, config);
case KernelVariant::decode_d64_m128:
return dispatch_variant<KernelVariant::decode_d64_m128>(args, config);
case KernelVariant::decode_d64_m64:
return dispatch_variant<KernelVariant::decode_d64_m64>(args, config);
case KernelVariant::decode_d64_m16:
return dispatch_variant<KernelVariant::decode_d64_m16>(args, config);
}
return std::make_pair(false, -1.f);
}

View File

@@ -36,17 +36,20 @@ namespace ck_tile {
// =============================================================================
enum class KernelVariant
{
// d=128 MHA (num_queries_per_kv = 1)
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
decode_d128_mha_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy)
decode_d128_mha_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
// d=128 (num_queries_per_kv chosen at *runtime* — same binary serves both
// MHA and GQA-N as long as num_qpkv divides kBlockM). kBlockM is the only
// structural compile-time knob; pick the tier by max_q after multiplying
// by num_qpkv in select_config.
prefill_d128, // kBlockM=256, 8 warps, 32x32 mfma
decode_d128_m128, // kBlockM=128, 4 warps, 32x32 mfma
decode_d128_m32, // kBlockM=32, 1 warp, 32x32 mfma (tiny-decode policy)
decode_d128_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
// d=64 GQA-8 (num_queries_per_kv = 8)
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
decode_d64_gqa8_m128, // kBlockM=128, 4 warps, 32x32 mfma
decode_d64_gqa8_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
decode_d64_gqa8_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
// d=64.
prefill_d64, // kBlockM=256, 8 warps, 32x32 mfma
decode_d64_m128, // kBlockM=128, 4 warps, 32x32 mfma
decode_d64_m64, // kBlockM=64, 2 warps, 32x32 mfma (decode policy)
decode_d64_m16, // kBlockM=16, 1 warp, 16x16 mfma (tiny-decode policy)
};
// -----------------------------------------------------------------------------
@@ -81,22 +84,25 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
//
// HeadSize : head dimension (compile-time)
// BlockM : Q-tile size along the M (token) axis
// NumQPerKV : 1 for MHA, 8 for GQA-8
// BlockSize : kBlockN — KV-tile size along the N axis
// BlockWarps : warp layout, sequence<M, N, K>
// WarpGemmShape : MFMA tile shape, sequence<M, N, K>
// Pipeline<P> : pipeline template (default vs decode vs tiny-decode policy)
// kUseDecodeGrid : selects 2D-by-seq grid (true) vs Q-block grid (false)
//
// num_queries_per_kv is *not* a compile-time knob: kBlockQ = kBlockM /
// num_qpkv is computed at runtime inside the kernel and pipeline. The only
// constraint is `kBlockM % num_qpkv == 0` (host-side select_config makes sure
// of this).
// =============================================================================
template <KernelVariant V>
struct variant_config;
template <>
struct variant_config<KernelVariant::prefill_d128_mha>
struct variant_config<KernelVariant::prefill_d128>
{
static constexpr index_t HeadSize = 128;
static constexpr index_t BlockM = 256;
static constexpr index_t NumQPerKV = 1;
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<8, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -106,11 +112,10 @@ struct variant_config<KernelVariant::prefill_d128_mha>
};
template <>
struct variant_config<KernelVariant::decode_d128_mha_m128>
struct variant_config<KernelVariant::decode_d128_m128>
{
static constexpr index_t HeadSize = 128;
static constexpr index_t BlockM = 128;
static constexpr index_t NumQPerKV = 1;
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<4, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -120,11 +125,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m128>
};
template <>
struct variant_config<KernelVariant::decode_d128_mha_m32>
struct variant_config<KernelVariant::decode_d128_m32>
{
static constexpr index_t HeadSize = 128;
static constexpr index_t BlockM = 32;
static constexpr index_t NumQPerKV = 1;
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -134,11 +138,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m32>
};
template <>
struct variant_config<KernelVariant::decode_d128_mha_m16>
struct variant_config<KernelVariant::decode_d128_m16>
{
static constexpr index_t HeadSize = 128;
static constexpr index_t BlockM = 16;
static constexpr index_t NumQPerKV = 1;
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<16, 16, 32>;
@@ -148,11 +151,10 @@ struct variant_config<KernelVariant::decode_d128_mha_m16>
};
template <>
struct variant_config<KernelVariant::prefill_d64_gqa8>
struct variant_config<KernelVariant::prefill_d64>
{
static constexpr index_t HeadSize = 64;
static constexpr index_t BlockM = 256;
static constexpr index_t NumQPerKV = 8;
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<8, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -162,11 +164,10 @@ struct variant_config<KernelVariant::prefill_d64_gqa8>
};
template <>
struct variant_config<KernelVariant::decode_d64_gqa8_m128>
struct variant_config<KernelVariant::decode_d64_m128>
{
static constexpr index_t HeadSize = 64;
static constexpr index_t BlockM = 128;
static constexpr index_t NumQPerKV = 8;
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<4, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -176,11 +177,10 @@ struct variant_config<KernelVariant::decode_d64_gqa8_m128>
};
template <>
struct variant_config<KernelVariant::decode_d64_gqa8_m64>
struct variant_config<KernelVariant::decode_d64_m64>
{
static constexpr index_t HeadSize = 64;
static constexpr index_t BlockM = 64;
static constexpr index_t NumQPerKV = 8;
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<2, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
@@ -190,11 +190,10 @@ struct variant_config<KernelVariant::decode_d64_gqa8_m64>
};
template <>
struct variant_config<KernelVariant::decode_d64_gqa8_m16>
struct variant_config<KernelVariant::decode_d64_m16>
{
static constexpr index_t HeadSize = 64;
static constexpr index_t BlockM = 16;
static constexpr index_t NumQPerKV = 8;
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<16, 16, 32>;
@@ -221,14 +220,19 @@ struct unified_attention_kernel_traits
static constexpr bool is_masking = IsMasking;
static constexpr KernelVariant variant = V;
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
static constexpr index_t kBlockM = cfg::BlockM;
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
static constexpr index_t num_queries_per_kv = cfg::NumQPerKV;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
static constexpr index_t kBlockM = cfg::BlockM;
static constexpr index_t BLOCK_SIZE = cfg::BlockSize;
static constexpr bool kUseDecodeGrid = cfg::kUseDecodeGrid;
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
// The 2nd entry of the BlockTile is the static `kBlockQ` exposed via
// `UnifiedAttentionShape::kBlockQ`. Now that the kernel always reads
// kBlockQ from `args.num_queries_per_kv` at runtime, this static value
// is only the fallback when no num_qpkv was plumbed through (which never
// happens in practice). Anchor it at kBlockM so the static "looks like
// num_qpkv == 1" and any (kBlockM, num_qpkv) such that kBlockM % num_qpkv
// == 0 works without touching this trait.
using unified_attention_block_tile = sequence<kBlockM, kBlockM, BLOCK_SIZE, HEAD_SIZE>;
using unified_attention_warp_gemm_shape = typename cfg::WarpGemmShape;
using unified_attention_block_warps = typename cfg::BlockWarps;
@@ -281,8 +285,12 @@ template <typename Kernel, bool UseDecodeGrid = false>
float unified_attention_kernel_launch(const unified_attention_args& args,
const stream_config& config)
{
constexpr index_t kBlockQ = Kernel::kBlockQ;
index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
// kBlockQ is derived from the runtime num_queries_per_kv now -- the
// static `Kernel::kBlockQ` is anchored at kBlockM and would over-count
// tiles for GQA workloads. We assert kBlockM % num_qpkv == 0 in
// select_config so this integer divide is always exact.
const index_t kBlockQ = Kernel::kBlockM / args.num_queries_per_kv;
const index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,

View File

@@ -263,13 +263,14 @@ struct UnifiedAttentionKernel
const index_t num_queries_per_kv = kargs.num_queries_per_kv;
// kBlockQ derived at runtime from num_queries_per_kv. For the variants
// we ship today this matches the compile-time `kBlockQ` from the
// pipeline trait (the assert below catches any disagreement); the
// explicit runtime form is what eventually lets a single kernel
// instantiation cover multiple num_queries_per_kv values.
// kBlockQ derived at runtime from num_queries_per_kv. The static
// `kBlockQ` from the pipeline trait is anchored at kBlockM (i.e. it
// describes num_qpkv == 1) so the same compiled binary serves every
// num_qpkv that divides kBlockM evenly -- e.g. the d=128 variants
// can run both MHA and GQA-N at runtime with no recompile. The host
// side (select_config) is responsible for enforcing kBlockM %
// num_queries_per_kv == 0.
const index_t kBlockQ_dyn = kBlockM / num_queries_per_kv;
assert(kBlockQ_dyn == kBlockQ);
// Split-KV: each CTA handles one (kv_head, q_block, split) tuple. The
// split index lives in z — when num_splits == 1 (the only z value)