mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add d=128 MHA decode variant (decode_d128_mha_m128)
Until now every d=128 MHA workload took the 8-warp prefill kernel
(kBlockM=256, kBlockQ=256), wasting 255/256 Q rows on pure-decode
shapes where Q is 1. Add a dedicated 4-warp decode variant with
kBlockM=128 (kBlockQ=128) that cuts the Q-tile waste roughly in half.
* Four new instance files at instances/unified_attention_d128_*_decode.cpp,
each instantiating unified_attention_decode_kernel_traits<dt, mask, 128, 128, 1>.
* KernelVariant::decode_d128_mha_m128 wired into select_config: chosen
when both avg_q and max_seqlen_q fit in 128; otherwise it falls back to
the prefill variant.
Tests: ua-test-scripts/test_unified_attention_ck_correctness.py stays at
236/240 -- the pure-decode seq_lens pattern in head_config=(16,16,128)
now routes to the new variant and matches the torch reference. The 4
remaining failures are the pre-existing int32-overflow case (orthogonal).
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// Explicit instantiation of the d=128 MHA decode variant
// (kBlockM = kBlockQ = 128, num_queries_per_kv = 1): bf16, mask enabled.
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 128, 128, 1>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// Explicit instantiation of the d=128 MHA decode variant
// (kBlockM = kBlockQ = 128, num_queries_per_kv = 1): bf16, no mask.
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 128, 128, 1>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// Explicit instantiation of the d=128 MHA decode variant
// (kBlockM = kBlockQ = 128, num_queries_per_kv = 1): fp16, mask enabled.
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 128, 128, 1>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"

namespace ck_tile {

// Explicit instantiation of the d=128 MHA decode variant
// (kBlockM = kBlockQ = 128, num_queries_per_kv = 1): fp16, no mask.
using kernel_traits =
    unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 128, 128, 1>;

INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

} // namespace ck_tile
|
||||
@@ -49,6 +49,7 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
enum class KernelVariant {
|
||||
// d=128 MHA (num_queries_per_kv = 1)
|
||||
prefill_d128_mha, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
decode_d128_mha_m128, // kBlockM=128, 4 warps, 32x32 mfma (kBlockQ=128)
|
||||
|
||||
// d=64 GQA-8 (num_queries_per_kv = 8)
|
||||
prefill_d64_gqa8, // kBlockM=256, 8 warps, 32x32 mfma
|
||||
@@ -95,12 +96,21 @@ KernelConfig select_config(const unified_attention_args& args)
|
||||
{
|
||||
KernelConfig cfg;
|
||||
|
||||
// d=128 MHA — only the 8-warp prefill kernel exists today. A dedicated
|
||||
// d=128 decode variant is the next commit; until then all d=128 traffic
|
||||
// takes the prefill kernel (Q-tile waste for short Q, but correct).
|
||||
// d=128 MHA — two variants today:
|
||||
// * decode_d128_mha_m128 : kBlockM=128, 4 warps. Used when avg/max Q
|
||||
// both fit in 128. Cuts Q-tile waste roughly 2x vs prefill for
|
||||
// short-Q workloads.
|
||||
// * prefill_d128_mha : kBlockM=256, 8 warps. Everything else.
|
||||
if (args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
cfg.variant = KernelVariant::prefill_d128_mha;
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs
|
||||
: args.num_tokens;
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
|
||||
if (avg_q <= 128 && max_q <= 128)
|
||||
cfg.variant = KernelVariant::decode_d128_mha_m128;
|
||||
else
|
||||
cfg.variant = KernelVariant::prefill_d128_mha;
|
||||
return cfg;
|
||||
}
|
||||
|
||||
@@ -207,6 +217,20 @@ std::pair<bool, float> dispatch_prefill_d128_mha(
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_decode_d128_mha_m128(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
if (args.data_type == DType::fp16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, false, 128, 128, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::fp16, true, 128, 128, 1)
|
||||
} else if (args.data_type == DType::bf16) {
|
||||
if (!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, false, 128, 128, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType::bf16, true, 128, 128, 1)
|
||||
}
|
||||
return {false, -1.f};
|
||||
}
|
||||
|
||||
std::pair<bool, float> dispatch_prefill_d64_gqa8(
|
||||
const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
@@ -324,6 +348,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
switch (cfg.variant)
|
||||
{
|
||||
case KernelVariant::prefill_d128_mha: return dispatch_prefill_d128_mha(args, config);
|
||||
case KernelVariant::decode_d128_mha_m128: return dispatch_decode_d128_mha_m128(args, config);
|
||||
case KernelVariant::prefill_d64_gqa8: return dispatch_prefill_d64_gqa8(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128: return dispatch_decode_d64_gqa8_m128(args, config);
|
||||
case KernelVariant::decode_d64_gqa8_m128_p32: return dispatch_decode_d64_gqa8_m128_p32(args, config);
|
||||
|
||||
Reference in New Issue
Block a user