# Patch summary (commentary only; text before the first "diff --git" header is
# not part of any hunk).
#
# Purpose: add head-dim 64 / GQA-8 kernel instances to the ck_tile
# 42_unified_attention example and route to them from the host dispatcher.
#
# 1. New files under example/ck_tile/42_unified_attention/instances/:
#      unified_attention_d64_bf16_mask_gqa8.cpp
#      unified_attention_d64_bf16_nmask_gqa8.cpp
#      unified_attention_d64_fp16_mask_gqa8.cpp
#      unified_attention_d64_fp16_nmask_gqa8.cpp
#    Each one includes unified_attention.hpp and unified_attention_impl.hpp,
#    declares a `kernel_traits` alias of unified_attention_kernel_traits inside
#    namespace ck_tile, and invokes INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits).
#
# 2. unified_attention.cpp:
#    - drops the comment stating that only d128 MHA instances are available and
#      labels the existing branch "d128, MHA (num_queries_per_kv == 1)";
#    - adds a branch for args.hdim == 64 && args.num_queries_per_kv == 8 that
#      selects on data_type (fp16 / bf16) and is_mask, calling
#      DISPATCH_UNIFIED_ATTENTION(<data type>, <mask flag>, 64, 256, 8);
#    - the std::cerr "no matching kernel instance" message still follows both
#      branches in the hunk.
#
# 3. unified_attention_impl.hpp (unified_attention_kernel_traits):
#    - BLOCK_SIZE changes from the constant 32 to (HEAD_SIZE <= 64) ? 64 : 32;
#    - its declaration moves below HEAD_SIZE, which it now reads.
#
# NOTE(review): this copy of the patch has lost its line breaks, and template
# argument lists appear to be missing (`unified_attention_kernel_traits;`,
# `std::pair unified_attention(`, `static_cast(mask_enum::no_mask)`).
# Presumably the angle-bracket contents were stripped during extraction --
# confirm against the original patch before applying.
#
# NOTE(review): per the formula quoted in the added comment
# (NumIssues = kPageBlockSize*HeadSize/4096), a head size of 32 with
# BLOCK_SIZE = 64 would give 64*32/4096 < 1. This assumes BLOCK_SIZE is the
# kPageBlockSize of that formula -- TODO confirm, and confirm whether head
# sizes below 64 are ever instantiated.
#
diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp new file mode 100644 index 0000000000..0b6be68278 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp new file mode 100644 index 0000000000..6bd3dd6f58 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp new file mode 100644 index 0000000000..28ff9f22b1 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp new file mode 100644 index 0000000000..f4d83a06a0 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index f9f88a7913..75a1629933 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -32,9 +32,10 @@ std::pair unified_attention(const unified_attention_args& args, const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); // Route based on (data_type, mask, hdim, num_queries_per_kv). - // Only d128 MHA (8 warps, kBlockM=256) instances available. // Decode-tuned instances require pipeline changes (NumWarpGroups must == 2, // which means exactly 8 warps; fewer warps are not supported). 
+ + // d128, MHA (num_queries_per_kv == 1) if(args.hdim == 128 && args.num_queries_per_kv == 1) { if(args.data_type == unified_attention_args::data_type_enum::fp16) @@ -49,6 +50,21 @@ std::pair unified_attention(const unified_attention_args& args, } } + // d64, GQA-8 (num_queries_per_kv == 8) + if(args.hdim == 64 && args.num_queries_per_kv == 8) + { + if(args.data_type == unified_attention_args::data_type_enum::fp16) + { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 64, 256, 8) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 64, 256, 8) + } + else if(args.data_type == unified_attention_args::data_type_enum::bf16) + { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 64, 256, 8) + else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 64, 256, 8) + } + } + std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim << " num_queries_per_kv=" << args.num_queries_per_kv << " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl; diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index bbd5ffb912..77c97ce109 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -64,8 +64,10 @@ struct unified_attention_kernel_traits static constexpr bool is_masking = IsMasking; static constexpr index_t kBlockM = BlockM_; - static constexpr index_t BLOCK_SIZE = 32; static constexpr index_t HEAD_SIZE = HeadSize_; + // On gfx950 with 128-bit loads (KVector=8), NumIssues = kPageBlockSize*HeadSize/4096. + // For HeadSize<=64 we need kPageBlockSize>=64 to keep NumIssues>=1. + static constexpr index_t BLOCK_SIZE = (HEAD_SIZE <= 64) ? 
64 : 32; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;