Add unified attention d64/GQA-8 kernel instances and fix BLOCK_SIZE for small head dims

The unified attention kernel previously only supported head_size=128
with MHA (NumQPerKV=1). This change adds support for head_size=64 with
GQA-8 (NumQPerKV=8), which is the configuration used by models like
DeepSeek-V3/R1 (64 query heads, 8 KV heads, head_dim=64).
Changes:
- Add 4 new kernel instance files for d64/GQA-8:
  unified_attention_d64_{bf16,fp16}_{nmask,mask}_gqa8.cpp
- Add d64/GQA-8 dispatch path in unified_attention.cpp
- Fix BLOCK_SIZE (kPageBlockSize) in unified_attention_kernel_traits:
  compute from HEAD_SIZE instead of hardcoding 32. For HeadSize<=64,
  BLOCK_SIZE must be 64 to guarantee NumIssues>=1 on gfx950. With
  128-bit vector loads (KVector=8), LaneGroups*NumWarps=128 exceeds
  kPageBlockSize=32 when HeadSize=64, causing a division-by-zero in
  the LDS tile descriptor constexpr evaluation.
This commit is contained in:
Amir Ghamarian
2026-03-27 09:41:10 -05:00
parent 5cd4b441ab
commit bbc748defe
6 changed files with 76 additions and 2 deletions

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 256, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 256, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 256, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 256, 8>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -32,9 +32,10 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
// Route based on (data_type, mask, hdim, num_queries_per_kv).
// Only d128 MHA (8 warps, kBlockM=256) instances available.
// Decode-tuned instances require pipeline changes (NumWarpGroups must == 2,
// which means exactly 8 warps; fewer warps are not supported).
// d128, MHA (num_queries_per_kv == 1)
if(args.hdim == 128 && args.num_queries_per_kv == 1)
{
if(args.data_type == unified_attention_args::data_type_enum::fp16)
@@ -49,6 +50,21 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
}
}
// d64, GQA-8 (num_queries_per_kv == 8)
if(args.hdim == 64 && args.num_queries_per_kv == 8)
{
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 64, 256, 8)
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 64, 256, 8)
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 64, 256, 8)
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 64, 256, 8)
}
}
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
<< " num_queries_per_kv=" << args.num_queries_per_kv
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl;

View File

@@ -64,8 +64,10 @@ struct unified_attention_kernel_traits
static constexpr bool is_masking = IsMasking;
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t BLOCK_SIZE = 32;
static constexpr index_t HEAD_SIZE = HeadSize_;
// On gfx950 with 128-bit loads (KVector=8), NumIssues = kPageBlockSize*HeadSize/4096.
// For HeadSize<=64 we need kPageBlockSize>=64 to keep NumIssues>=1.
static constexpr index_t BLOCK_SIZE = (HEAD_SIZE <= 64) ? 64 : 32;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;