Add d=128 MHA decode variant (decode_d128_mha_m128)

Until now every d=128 MHA workload took the 8-warp prefill kernel
(kBlockM=256, kBlockQ=256), wasting 255/256 Q rows on pure-decode
shapes where Q is 1. Add a dedicated 4-warp decode variant with
kBlockM=128 (kBlockQ=128) that cuts the Q-tile waste roughly in half.
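
As a sanity check on that arithmetic, here is a tiny stand-alone sketch
(illustrative only, not kernel code; the helper name is made up here):

    #include <cstdio>

    // Fraction of Q rows in a tile that carry real queries when one
    // tile is launched per sequence (illustrative helper, not from the
    // kernel sources).
    static double q_tile_utilization(int q_len, int k_block_m)
    {
        return q_len >= k_block_m ? 1.0
                                  : static_cast<double>(q_len) / k_block_m;
    }

    int main()
    {
        // Pure decode: q_len = 1 per sequence.
        std::printf("prefill kBlockM=256: %.5f\n", q_tile_utilization(1, 256)); // 0.00391
        std::printf("decode  kBlockM=128: %.5f\n", q_tile_utilization(1, 128)); // 0.00781, ~2x
        return 0;
    }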

  * Four new instance files at instances/unified_attention_d128_*_decode.cpp,
    each instantiating unified_attention_decode_kernel_traits<dt, mask, 128, 128, 1>.
  * KernelVariant::decode_d128_mha_m128 wired into select_config: chosen
    when both avg_q and max_seqlen_q fit in 128, otherwise falling back to
    the prefill variant (see the routing sketch after this list).
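
A minimal routing sketch under stated assumptions -- the field names mirror
the select_config hunk below, but the header path, the ck_tile::
qualification of select_config, and the zero-initialized defaults are
assumptions for illustration:

    #include "unified_attention.hpp" // assumed to declare the args struct and select_config

    void route_example()
    {
        // Pure-decode d=128 MHA batch: 16 sequences, one new token each.
        ck_tile::unified_attention_args args{};
        args.hdim               = 128;
        args.num_queries_per_kv = 1;
        args.num_tokens         = 16;
        args.num_seqs           = 16;  // avg_q = num_tokens / num_seqs = 1
        args.max_seqlen_q       = 1;   // max_q = 1 <= 128

        auto cfg = ck_tile::select_config(args);
        (void)cfg; // cfg.variant == KernelVariant::decode_d128_mha_m128 here;
                   // raising max_seqlen_q (or avg_q) above 128 falls back to
                   // prefill_d128_mha.
    }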

Tests: ua-test-scripts/test_unified_attention_ck_correctness.py stays at
236/240 -- the pure-decode seq_lens pattern in head_config=(16,16,128)
now routes to the new variant and matches the torch reference. The 4
remaining failures are the pre-existing int32-overflow case, which is
orthogonal to this change.

Co-authored-by: Cursor <cursoragent@cursor.com>

@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile

@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile

@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile

@@ -0,0 +1,14 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+using kernel_traits =
+    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 128, 128, 1>;
+
+INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
+
+} // namespace ck_tile

@@ -49,6 +49,7 @@ std::ostream& operator<<(std::ostream& stream,
 enum class KernelVariant {
     // d=128 MHA (num_queries_per_kv = 1)
     prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
+    decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
     // d=64 GQA-8 (num_queries_per_kv = 8)
     prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
@@ -95,12 +96,21 @@ KernelConfig select_config(const unified_attention_args& args)
 {
     KernelConfig cfg;
-    // d=128 MHA — only the 8-warp prefill kernel exists today. A dedicated
-    // d=128 decode variant is the next commit; until then all d=128 traffic
-    // takes the prefill kernel (Q-tile waste for short Q, but correct).
+    // d=128 MHA — two variants today:
+    //   * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q
+    //     both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
+    //     short-Q workloads.
+    //   * prefill_d128_mha : kBlockM=256, 8 warps. Everything else.
     if (args.hdim == 128 && args.num_queries_per_kv == 1)
     {
-        cfg.variant = KernelVariant::prefill_d128_mha;
+        const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
+                                                : args.num_tokens;
+        const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
+        if (avg_q <= 128 && max_q <= 128)
+            cfg.variant = KernelVariant::decode_d128_mha_m128;
+        else
+            cfg.variant = KernelVariant::prefill_d128_mha;
         return cfg;
     }
@@ -207,6 +217,20 @@ std::pair<bool, float> dispatch_prefill_d128_mha(
     return {false, -1.f};
 }
 
+std::pair<bool, float> dispatch_decode_d128_mha_m128(
+    const unified_attention_args& args, const stream_config& config)
+{
+    const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
+    if (args.data_type == DType::fp16) {
+        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1)
+        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1)
+    } else if (args.data_type == DType::bf16) {
+        if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1)
+        else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1)
+    }
+    return {false, -1.f};
+}
+
 std::pair<bool, float> dispatch_prefill_d64_gqa8(
     const unified_attention_args& args, const stream_config& config)
 {
@@ -324,6 +348,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
     switch (cfg.variant)
     {
     case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
+    case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
     case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
     case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
     case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config);