CK-UA int32 overflow protection: add explicit template parameter for large caches

Add CachePtrInt32OverflowPossible template parameter (default false) to all
unified attention kernel traits. This enables dual kernel variants:
- Small cache (false): compile-time elimination of overflow checks for <100K blocks
- Large cache (true): runtime overflow checking with pointer rebasing for >=100K blocks

Key changes:
- Add CachePtrInt32OverflowPossible as 14th template parameter to UnifiedAttentionPipelineProblem
- Pass parameter through all kernel traits: decode, decode_small, decode_tiny, decode_bs32
- Implement overflow checking in pipeline with if constexpr for zero overhead when disabled
- Update dispatch macros with _SMALL_CACHE and _LARGE_CACHE variants
- Create instance files for both small and large cache variants (narrow, _s, _m tiers)
- Remove old MAX_NUM_BLOCKS inference logic (num_kv_heads is runtime, cannot infer)

Python calculates overflow possibility based on actual cache size and passes
it explicitly via cache_ptr_int32_overflow_possible parameter.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
juuso-oskari
2026-05-08 10:15:38 +00:00
parent e9cf036a81
commit 397febf42c
33 changed files with 262 additions and 216 deletions

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -6,9 +6,9 @@
namespace ck_tile {
// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, 100000>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -6,9 +6,9 @@
namespace ck_tile {
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, 100000>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32>;
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -6,9 +6,9 @@
namespace ck_tile {
// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, 100000>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -6,9 +6,9 @@
namespace ck_tile {
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks)
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, 100000>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32>;
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32>;
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32>;
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32>;
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32>;
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
using kernel_traits =
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32>;
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks)
using kernel_traits =
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32, false>;
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32>;
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32>;
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
using kernel_traits =
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
} // namespace ck_tile

View File

@@ -64,24 +64,43 @@ std::ostream& operator<<(std::ostream& stream,
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// Small-cache variants (7th template arg = MaxNumBlocks for compile-time overflow elimination).
// For d64/GQA-8/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks.
// Set MaxNumBlocks = 100,000 (conservative, safe for ~98K blocks) to guarantee no overflow.
// Small-cache variants (7th template arg = CachePtrInt32OverflowPossible=false).
// For small caches (<100K blocks), we can guarantee no int32 overflow, so compile-time eliminate overflow checks.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// Large-cache variants (7th template arg = CachePtrInt32OverflowPossible=true).
// For large caches (>=100K blocks), enable runtime overflow checking with pointer rebasing.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
@@ -108,26 +127,14 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
return tile_tier::medium;
}
// Select between small-cache (compile-time overflow elimination) and large-cache variants.
// For d64/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks
// We use 100,000 as the small-cache limit (conservative, safe for ~98K blocks)
static bool use_small_cache_variant(const unified_attention_args& args)
{
// Only optimize for d64 with block_size < 64 (bs32 variants)
if(args.hdim != 64 || args.page_blk_size >= 64)
return false;
// Conservative threshold: 100,000 blocks (~98K)
// This guarantees no int32 overflow for d64/bs32
constexpr index_t kSmallCacheThreshold = 100000;
return args.num_blks <= kSmallCacheThreshold;
}
std::pair<bool, float> unified_attention(const unified_attention_args& args,
const stream_config& config)
{
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
const auto tier = select_tile_tier(args);
// Python calculates overflow possibility and passes it directly
const bool use_small_cache = !args.cache_ptr_int32_overflow_possible;
// d128, MHA (num_queries_per_kv == 1)
if(args.hdim == 128 && args.num_queries_per_kv == 1)
@@ -148,7 +155,6 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(args.hdim == 64 && args.num_queries_per_kv == 8)
{
const bool use_bs32 = (args.page_blk_size < 64);
const bool use_small_cache = use_small_cache_variant(args);
if(tier == tile_tier::tiny)
{
@@ -157,13 +163,23 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
// Avoids 1-warp race condition; 2x less waste than small tier.
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
if(use_small_cache) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
}
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
if(use_small_cache) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
}
}
} else {
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
@@ -184,8 +200,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(use_bs32) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
if(use_small_cache) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
}
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
@@ -198,8 +219,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
}
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
@@ -212,8 +233,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(use_bs32) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
if(use_small_cache) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
}
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
@@ -222,8 +248,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(use_bs32) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
if(use_small_cache) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
}
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)

View File

@@ -67,6 +67,8 @@ struct unified_attention_args
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
bool cache_ptr_int32_overflow_possible = false; // true = use large cache variant with overflow checks
};
std::ostream& operator<<(std::ostream& stream,

View File

@@ -67,7 +67,8 @@ template <unified_attention_args::data_type_enum DataType,
index_t HeadSize_ = 128,
index_t BlockM_ = 256,
index_t NumQPerKV_ = 1,
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
bool CachePtrInt32OverflowPossible_ = false>
struct unified_attention_kernel_traits
{
static constexpr auto date_type = DataType;
@@ -116,7 +117,7 @@ struct unified_attention_kernel_traits
unified_attention_shape,
unified_attention_mask,
unified_attention_traits,
-1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles
CachePtrInt32OverflowPossible_>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
@@ -139,7 +140,7 @@ template <unified_attention_args::data_type_enum DataType,
index_t BlockM_ = 128,
index_t NumQPerKV_ = 1,
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
bool CachePtrInt32OverflowPossible_ = false> // Default false = no overflow expected
struct unified_attention_decode_kernel_traits
{
static constexpr auto date_type = DataType;
@@ -148,7 +149,6 @@ struct unified_attention_decode_kernel_traits
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t BLOCK_SIZE = BlockSize_;
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
@@ -183,7 +183,7 @@ struct unified_attention_decode_kernel_traits
unified_attention_shape,
unified_attention_mask,
unified_attention_traits,
-1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles
CachePtrInt32OverflowPossible_>;
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
@@ -203,7 +203,7 @@ template <unified_attention_args::data_type_enum DataType,
index_t BlockM_ = 64,
index_t NumQPerKV_ = 8,
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
index_t MaxNumBlocks_ = -1>
bool CachePtrInt32OverflowPossible_ = false>
struct unified_attention_decode_small_kernel_traits
{
static constexpr auto date_type = DataType;
@@ -212,7 +212,6 @@ struct unified_attention_decode_small_kernel_traits
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t BLOCK_SIZE = BlockSize_;
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
@@ -246,7 +245,7 @@ struct unified_attention_decode_small_kernel_traits
unified_attention_shape,
unified_attention_mask,
unified_attention_traits,
MAX_NUM_BLOCKS>;
CachePtrInt32OverflowPossible_>;
using unified_attention_pipeline =
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
@@ -269,7 +268,7 @@ template <unified_attention_args::data_type_enum DataType,
index_t BlockM_ = 16,
index_t NumQPerKV_ = 8,
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
bool CachePtrInt32OverflowPossible_ = false>
struct unified_attention_decode_tiny_kernel_traits
{
static constexpr auto date_type = DataType;
@@ -278,7 +277,6 @@ struct unified_attention_decode_tiny_kernel_traits
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t BLOCK_SIZE = BlockSize_;
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
@@ -312,7 +310,7 @@ struct unified_attention_decode_tiny_kernel_traits
unified_attention_shape,
unified_attention_mask,
unified_attention_traits,
MAX_NUM_BLOCKS>;
CachePtrInt32OverflowPossible_>;
using unified_attention_pipeline =
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
@@ -335,7 +333,7 @@ template <unified_attention_args::data_type_enum DataType,
index_t BlockM_ = 32,
index_t NumQPerKV_ = 8,
index_t BlockSize_ = 32,
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
bool CachePtrInt32OverflowPossible_ = false>
struct unified_attention_decode_bs32_kernel_traits
{
static constexpr auto date_type = DataType;
@@ -344,7 +342,6 @@ struct unified_attention_decode_bs32_kernel_traits
static constexpr index_t kBlockM = BlockM_;
static constexpr index_t HEAD_SIZE = HeadSize_;
static constexpr index_t BLOCK_SIZE = BlockSize_;
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
static constexpr index_t num_queries_per_kv = NumQPerKV_;
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
@@ -377,7 +374,7 @@ struct unified_attention_decode_bs32_kernel_traits
unified_attention_shape,
unified_attention_mask,
unified_attention_traits,
MAX_NUM_BLOCKS>;
CachePtrInt32OverflowPossible_>;
using unified_attention_pipeline =
UnifiedAttentionPipeline<unified_attention_pipeline_problem,