mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add unified attention d64/GQA-8 kernel instances and fix BLOCK_SIZE for small head dims
The unified attention kernel previously only supported head_size=128
with MHA (NumQPerKV=1). This change adds support for head_size=64 with
GQA-8 (NumQPerKV=8), which is the configuration used by models like
DeepSeek-V3/R1 (64 query heads, 8 KV heads, head_dim=64).
Changes:
- Add 4 new kernel instance files for d64/GQA-8:
unified_attention_d64_{bf16,fp16}_{nmask,mask}_gqa8.cpp
- Add d64/GQA-8 dispatch path in unified_attention.cpp
- Fix BLOCK_SIZE (kPageBlockSize) in unified_attention_kernel_traits:
compute from HEAD_SIZE instead of hardcoding 32. For HeadSize<=64,
BLOCK_SIZE must be 64 to guarantee NumIssues>=1 on gfx950. With
128-bit vector loads (KVector=8), LaneGroups*NumWarps=128 exceeds
kPageBlockSize=32 when HeadSize=64, causing a division-by-zero in
the LDS tile descriptor constexpr evaluation.
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -32,9 +32,10 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
|
||||
// Route based on (data_type, mask, hdim, num_queries_per_kv).
|
||||
// Only d128 MHA (8 warps, kBlockM=256) instances available.
|
||||
// Decode-tuned instances require pipeline changes (NumWarpGroups must == 2,
|
||||
// which means exactly 8 warps; fewer warps are not supported).
|
||||
|
||||
// d128, MHA (num_queries_per_kv == 1)
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
@@ -49,6 +50,21 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
}
|
||||
}
|
||||
|
||||
// d64, GQA-8 (num_queries_per_kv == 8)
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 64, 256, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 64, 256, 8)
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
|
||||
<< " num_queries_per_kv=" << args.num_queries_per_kv
|
||||
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl;
|
||||
|
||||
@@ -64,8 +64,10 @@ struct unified_attention_kernel_traits
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t BLOCK_SIZE = 32;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
// On gfx950 with 128-bit loads (KVector=8), NumIssues = kPageBlockSize*HeadSize/4096.
|
||||
// For HeadSize<=64 we need kPageBlockSize>=64 to keep NumIssues>=1.
|
||||
static constexpr index_t BLOCK_SIZE = (HEAD_SIZE <= 64) ? 64 : 32;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
Reference in New Issue
Block a user