From bbc748defe44594e3af111dafdd8030b6291e95c Mon Sep 17 00:00:00 2001 From: Amir Ghamarian Date: Fri, 27 Mar 2026 09:41:10 -0500 Subject: [PATCH] Add unified attention d64/GQA-8 kernel instances and fix BLOCK_SIZE for small head dims The unified attention kernel previously only supported head_size=128 with MHA (NumQPerKV=1). This change adds support for head_size=64 with GQA-8 (NumQPerKV=8), which is the configuration used by models like DeepSeek-V3/R1 (64 query heads, 8 KV heads, head_dim=64). Changes: - Add 4 new kernel instance files for d64/GQA-8: unified_attention_d64_{bf16,fp16}_{nmask,mask}_gqa8.cpp - Add d64/GQA-8 dispatch path in unified_attention.cpp - Fix BLOCK_SIZE (kPageBlockSize) in unified_attention_kernel_traits: compute from HEAD_SIZE instead of hardcoding 32. For HeadSize<=64, BLOCK_SIZE must be 64 to guarantee NumIssues>=1 on gfx950. With 128-bit vector loads (KVector=8), LaneGroups*NumWarps=128 exceeds kPageBlockSize=32 when HeadSize=64, causing a division-by-zero in the LDS tile descriptor constexpr evaluation. --- .../unified_attention_d64_bf16_mask_gqa8.cpp | 14 ++++++++++++++ .../unified_attention_d64_bf16_nmask_gqa8.cpp | 14 ++++++++++++++ .../unified_attention_d64_fp16_mask_gqa8.cpp | 14 ++++++++++++++ .../unified_attention_d64_fp16_nmask_gqa8.cpp | 14 ++++++++++++++ .../42_unified_attention/unified_attention.cpp | 18 +++++++++++++++++- .../unified_attention_impl.hpp | 4 +++- 6 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp create mode 100644 example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp new file mode 100644 index 0000000000..0b6be68278 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp new file mode 100644 index 0000000000..6bd3dd6f58 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp new file mode 100644 index 0000000000..28ff9f22b1 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp new file mode 100644 index 0000000000..f4d83a06a0 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index f9f88a7913..75a1629933 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -32,9 +32,10 @@ std::pair unified_attention(const unified_attention_args& args, const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); // Route based on (data_type, mask, hdim, num_queries_per_kv). - // Only d128 MHA (8 warps, kBlockM=256) instances available. // Decode-tuned instances require pipeline changes (NumWarpGroups must == 2, // which means exactly 8 warps; fewer warps are not supported). + + // d128, MHA (num_queries_per_kv == 1) if(args.hdim == 128 && args.num_queries_per_kv == 1) { if(args.data_type == unified_attention_args::data_type_enum::fp16) @@ -49,6 +50,21 @@ std::pair unified_attention(const unified_attention_args& args, } } + // d64, GQA-8 (num_queries_per_kv == 8) + if(args.hdim == 64 && args.num_queries_per_kv == 8) + { + if(args.data_type == unified_attention_args::data_type_enum::fp16) + { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 64, 256, 8) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 64, 256, 8) + } + else if(args.data_type == unified_attention_args::data_type_enum::bf16) + { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 64, 256, 8) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 64, 256, 8) + } + } + std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim << " num_queries_per_kv=" << args.num_queries_per_kv << " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl; diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index bbd5ffb912..77c97ce109 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -64,8 +64,10 @@ struct unified_attention_kernel_traits static constexpr bool is_masking = IsMasking; static constexpr index_t kBlockM = BlockM_; - static constexpr index_t BLOCK_SIZE = 32; static constexpr index_t HEAD_SIZE = HeadSize_; + // On gfx950 with 128-bit loads (KVector=8), NumIssues = kPageBlockSize*HeadSize/4096. + // For HeadSize<=64 we need kPageBlockSize>=64 to keep NumIssues>=1. + static constexpr index_t BLOCK_SIZE = (HEAD_SIZE <= 64) ? 64 : 32; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;