Add CK-UA decode_d128_mha_m32 / _m16 small-Q tiers

For pure-decode workloads (sq=1) at d=128 the m128 tile wastes most of
its 128 query rows, capping CK below Triton at every batch size in our
sweep (4..256). Add two small-Q tiers that mirror the d=64 GQA-8 ladder:

  * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 MFMA  (tiny-decode)
  * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 MFMA  (tiny-decode)

select_config now ladders by (avg_q, max_q): m16 -> m32 -> m128 -> prefill.
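
For intuition on the waste being recovered (straightforward arithmetic,
not a measurement): a pure-decode block computes kBlockM query rows but
only one per sequence is real, so Q-row utilization is 1/kBlockM.

  // Sketch of the Q-row utilization argument above; compiles standalone.
  // The only inputs are the kBlockM values of the tiers.
  constexpr double q_util(int kBlockM) { return 1.0 / kBlockM; }
  static_assert(q_util(128) < 0.01, "m128: <1% of Q-row work is useful at sq=1");
  static_assert(q_util(16) == 1.0 / 16, "m16: 6.25%, i.e. 8x less wasted Q work");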

d=128 MHA, hq=16/hk=16, sq=1, sk=120k, num_blocks=60k
(before/after = speedup vs Triton; CK BW = achieved CK bandwidth):
  batch  before    after    CK BW
      4  ~0.95x   0.98x   4.76 TB/s
      8  ~0.85x   1.29x   5.00 TB/s
     32  ~0.85x   1.14x   5.29 TB/s
     64  ~0.75x   0.93x   5.35 TB/s
    128  ~1.00x   1.09x   5.39 TB/s
    256  ~1.03x   1.02x   5.41 TB/s
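
The BW column is consistent with first-principles KV traffic (my
arithmetic, assuming a bf16 KV cache with each K/V element read once;
Q/O traffic is negligible at sq=1):

  // Back-of-envelope check of the CK BW column above.
  #include <cstdio>
  #include <initializer_list>
  int main() {
      const double hk = 16, sk = 120e3, d = 128, elem_bytes = 2; // bf16
      for (double batch : {4.0, 256.0}) {
          const double kv_bytes = batch * hk * sk * d * elem_bytes * 2; // K + V
          std::printf("batch %3.0f: %.2f GB KV traffic per call\n",
                      batch, kv_bytes / 1e9);
      }
      // batch 4 -> ~3.93 GB; at the measured 4.76 TB/s that is ~0.83 ms/call.
      return 0;
  }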

Correctness suite stays at 241/245 (same 4 known int32-overflow
failures in the prefill path).

Co-authored-by: Cursor <cursoragent@cursor.com>
commit f5beedb2e9 (parent fb0d729fbb)
Author: juuso-oskari
Date:   2026-05-12 11:48:19 +00:00

10 changed files with 132 additions and 6 deletions

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, true)
+
+} // namespace ck_tile
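
(The eight files in this commit are one translation unit per
variant x dtype x causal-flag combination, so the heavy template
instantiations build in parallel. A rough sketch of what such a
dispatch-instantiation macro typically expands to; the real definition
lives in unified_attention_impl.hpp, is not part of this diff, and may
take a different signature, so treat this as hypothetical:)

  // Hypothetical expansion sketch only — not the actual macro.
  #define INST_UNIFIED_ATTENTION_DISPATCH(variant, dtype, is_causal) \
      template std::pair<bool, float>                                \
      dispatch_variant<KernelVariant::variant, dtype, is_causal>(    \
          const unified_attention_args&, const KernelConfig&);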

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, true)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, bf16, false)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, bf16, false)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, true)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, true)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m32, fp16, false)
+
+} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "unified_attention.hpp"
+#include "unified_attention_impl.hpp"
+
+namespace ck_tile {
+
+INST_UNIFIED_ATTENTION_DISPATCH(decode_d128_mha_m16, fp16, false)
+
+} // namespace ck_tile

View File

@@ -81,18 +81,22 @@ KernelConfig select_config(const unified_attention_args& args)
 {
     KernelConfig cfg;
-    // d=128 MHA — two variants today:
-    //   * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q
-    //     both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
-    //     short-Q workloads.
-    //   * prefill_d128_mha     : kBlockM=256, 8 warps. Everything else.
+    // d=128 MHA — tile-tier ladder by (avg_q, max_q):
+    //   * decode_d128_mha_m16  : kBlockM=16,  1 warp,  16x16 mfma (tiny-decode)
+    //   * decode_d128_mha_m32  : kBlockM=32,  1 warp,  32x32 mfma (tiny-decode)
+    //   * decode_d128_mha_m128 : kBlockM=128, 4 warps, 32x32 mfma (default)
+    //   * prefill_d128_mha     : kBlockM=256, 8 warps, 32x32 mfma
     if(args.hdim == 128 && args.num_queries_per_kv == 1)
     {
         const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
                                                 : args.num_tokens;
         const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
-        if(avg_q <= 128 && max_q <= 128)
+        if(avg_q <= 16 && max_q <= 16)
+            cfg.variant = KernelVariant::decode_d128_mha_m16;
+        else if(avg_q <= 32 && max_q <= 32)
+            cfg.variant = KernelVariant::decode_d128_mha_m32;
+        else if(avg_q <= 128 && max_q <= 128)
             cfg.variant = KernelVariant::decode_d128_mha_m128;
         else
             cfg.variant = KernelVariant::prefill_d128_mha;
@@ -172,6 +176,10 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
         return dispatch_variant<KernelVariant::prefill_d128_mha>(args, config);
     case KernelVariant::decode_d128_mha_m128:
         return dispatch_variant<KernelVariant::decode_d128_mha_m128>(args, config);
+    case KernelVariant::decode_d128_mha_m32:
+        return dispatch_variant<KernelVariant::decode_d128_mha_m32>(args, config);
+    case KernelVariant::decode_d128_mha_m16:
+        return dispatch_variant<KernelVariant::decode_d128_mha_m16>(args, config);
     case KernelVariant::prefill_d64_gqa8:
         return dispatch_variant<KernelVariant::prefill_d64_gqa8>(args, config);
     case KernelVariant::decode_d64_gqa8_m128:
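
A caller never names the new tiers directly; pure-decode shapes fall
into them through the ladder above. A minimal usage sketch (field names
follow the diff; all other unified_attention_args members are assumed
default-initialized):

  unified_attention_args args{};
  args.hdim               = 128;
  args.num_queries_per_kv = 1;   // MHA
  args.num_seqs           = 64;  // 64 pure-decode sequences
  args.num_tokens         = 64;  // 1 token each -> avg_q = 64 / 64 = 1
  args.max_seqlen_q       = 1;   // max_q = 1
  // avg_q <= 16 && max_q <= 16 -> KernelVariant::decode_d128_mha_m16
  KernelConfig cfg = select_config(args);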

View File

@@ -39,6 +39,8 @@ enum class KernelVariant
     // d=128 MHA (num_queries_per_kv = 1)
     prefill_d128_mha,     // kBlockM=256, 8 warps, 32x32 mfma
     decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
+    decode_d128_mha_m32,  // kBlockM=32,  1 warp,  32x32 mfma (tiny-decode policy)
+    decode_d128_mha_m16,  // kBlockM=16,  1 warp,  16x16 mfma (tiny-decode policy)
     // d=64 GQA-8 (num_queries_per_kv = 8)
     prefill_d64_gqa8,     // kBlockM=256, 8 warps, 32x32 mfma
@@ -117,6 +119,34 @@ struct variant_config<KernelVariant::decode_d128_mha_m128>
     static constexpr bool kUseDecodeGrid = false;
 };
+template <>
+struct variant_config<KernelVariant::decode_d128_mha_m32>
+{
+    static constexpr index_t HeadSize  = 128;
+    static constexpr index_t BlockM    = 32;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps    = sequence<1, 1, 1>;
+    using WarpGemmShape = sequence<32, 32, 16>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy>;
+    static constexpr bool kUseDecodeGrid = true;
+};
+template <>
+struct variant_config<KernelVariant::decode_d128_mha_m16>
+{
+    static constexpr index_t HeadSize  = 128;
+    static constexpr index_t BlockM    = 16;
+    static constexpr index_t NumQPerKV = 1;
+    static constexpr index_t BlockSize = 32;
+    using BlockWarps    = sequence<1, 1, 1>;
+    using WarpGemmShape = sequence<16, 16, 32>;
+    template <typename Problem>
+    using Pipeline = UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy>;
+    static constexpr bool kUseDecodeGrid = true;
+};
 template <>
 struct variant_config<KernelVariant::prefill_d64_gqa8>
 {
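
Reading the new configs: BlockWarps = sequence<1, 1, 1> is a single
warp per block, and WarpGemmShape appears to be the per-warp MFMA tile
as <M, N, K> (my reading, matching the comments in the enum). A sanity
sketch of the tiling arithmetic that follows from those numbers:

  // Assumption: WarpGemmShape = <M, N, K> per-warp MFMA tile; values
  // copied from the two variant_config specializations above.
  namespace m32_tier {
  constexpr int BlockM = 32, HeadSize = 128, WarpM = 32, WarpK = 16; // 32x32x16
  static_assert(BlockM / WarpM == 1, "whole Q block in one M-step");
  static_assert(HeadSize / WarpK == 8, "QK contraction: 8 MFMA steps over d=128");
  } // namespace m32_tier
  namespace m16_tier {
  constexpr int BlockM = 16, HeadSize = 128, WarpM = 16, WarpK = 32; // 16x16x32
  static_assert(BlockM / WarpM == 1, "whole Q block in one M-step");
  static_assert(HeadSize / WarpK == 4, "QK contraction: 4 MFMA steps over d=128");
  } // namespace m16_tier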