mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
CK-UA int32 overflow protection: add explicit template parameter for large caches
Add CachePtrInt32OverflowPossible template parameter (default false) to all unified attention kernel traits. This enables dual kernel variants: - Small cache (false): compile-time elimination of overflow checks for <100K blocks - Large cache (true): runtime overflow checking with pointer rebasing for >=100K blocks Key changes: - Add CachePtrInt32OverflowPossible as 14th template parameter to UnifiedAttentionPipelineProblem - Pass parameter through all kernel traits: decode, decode_small, decode_tiny, decode_bs32 - Implement overflow checking in pipeline with if constexpr for zero overhead when disabled - Update dispatch macros with _SMALL_CACHE and _LARGE_CACHE variants - Create instance files for both small and large cache variants (narrow, _s, _m tiers) - Remove old MAX_NUM_BLOCKS inference logic (num_kv_heads is runtime, cannot infer) Python calculates overflow possibility based on actual cache size and passes it explicitly via cache_ptr_int32_overflow_possible parameter. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, 100000>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, 100000>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32>;
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, 100000>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead)
|
||||
// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, 100000>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32>;
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32>;
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32>;
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32>;
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32>;
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32>;
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks)
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32>;
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32>;
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32, true>; // Large cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32, false>; // Small cache: overflow checks enabled
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -64,24 +64,43 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Small-cache variants (7th template arg = MaxNumBlocks for compile-time overflow elimination).
|
||||
// For d64/GQA-8/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks.
|
||||
// Set MaxNumBlocks = 100,000 (conservative, safe for ~98K blocks) to guarantee no overflow.
|
||||
// Small-cache variants (7th template arg = CachePtrInt32OverflowPossible=false).
|
||||
// For small caches (<100K blocks), we can guarantee no int32 overflow, so compile-time eliminate overflow checks.
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, 100000>; \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, false>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Large-cache variants (7th template arg = CachePtrInt32OverflowPossible=true).
|
||||
// For large caches (>=100K blocks), enable runtime overflow checking with pointer rebasing.
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32, true>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
@@ -108,26 +127,14 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
return tile_tier::medium;
|
||||
}
|
||||
|
||||
// Select between small-cache (compile-time overflow elimination) and large-cache variants.
|
||||
// For d64/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks
|
||||
// We use 100,000 as the small-cache limit (conservative, safe for ~98K blocks)
|
||||
static bool use_small_cache_variant(const unified_attention_args& args)
|
||||
{
|
||||
// Only optimize for d64 with block_size < 64 (bs32 variants)
|
||||
if(args.hdim != 64 || args.page_blk_size >= 64)
|
||||
return false;
|
||||
|
||||
// Conservative threshold: 100,000 blocks (~98K)
|
||||
// This guarantees no int32 overflow for d64/bs32
|
||||
constexpr index_t kSmallCacheThreshold = 100000;
|
||||
return args.num_blks <= kSmallCacheThreshold;
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
const auto tier = select_tile_tier(args);
|
||||
// Python calculates overflow possibility and passes it directly
|
||||
const bool use_small_cache = !args.cache_ptr_int32_overflow_possible;
|
||||
|
||||
// d128, MHA (num_queries_per_kv == 1)
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
@@ -148,7 +155,6 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
const bool use_bs32 = (args.page_blk_size < 64);
|
||||
const bool use_small_cache = use_small_cache_variant(args);
|
||||
|
||||
if(tier == tile_tier::tiny)
|
||||
{
|
||||
@@ -157,13 +163,23 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
// Avoids 1-warp race condition; 2x less waste than small tier.
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
|
||||
}
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
|
||||
@@ -184,8 +200,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
}
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
@@ -198,8 +219,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
}
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
@@ -212,8 +233,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
}
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
@@ -222,8 +248,13 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
if(use_small_cache) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
}
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
|
||||
@@ -67,6 +67,8 @@ struct unified_attention_args
|
||||
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
|
||||
bool cache_ptr_int32_overflow_possible = false; // true = use large cache variant with overflow checks
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
|
||||
@@ -67,7 +67,8 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 256,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
bool CachePtrInt32OverflowPossible_ = false>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -116,7 +117,7 @@ struct unified_attention_kernel_traits
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits,
|
||||
-1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles
|
||||
CachePtrInt32OverflowPossible_>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
@@ -139,7 +140,7 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t BlockM_ = 128,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
bool CachePtrInt32OverflowPossible_ = false> // Default false = no overflow expected
|
||||
struct unified_attention_decode_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -148,7 +149,6 @@ struct unified_attention_decode_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -183,7 +183,7 @@ struct unified_attention_decode_kernel_traits
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits,
|
||||
-1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles
|
||||
CachePtrInt32OverflowPossible_>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
@@ -203,7 +203,7 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t BlockM_ = 64,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1>
|
||||
bool CachePtrInt32OverflowPossible_ = false>
|
||||
struct unified_attention_decode_small_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -212,7 +212,6 @@ struct unified_attention_decode_small_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -246,7 +245,7 @@ struct unified_attention_decode_small_kernel_traits
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
CachePtrInt32OverflowPossible_>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
@@ -269,7 +268,7 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t BlockM_ = 16,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
bool CachePtrInt32OverflowPossible_ = false>
|
||||
struct unified_attention_decode_tiny_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -278,7 +277,6 @@ struct unified_attention_decode_tiny_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -312,7 +310,7 @@ struct unified_attention_decode_tiny_kernel_traits
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
CachePtrInt32OverflowPossible_>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
@@ -335,7 +333,7 @@ template <unified_attention_args::data_type_enum DataType,
|
||||
index_t BlockM_ = 32,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = 32,
|
||||
index_t MaxNumBlocks_ = -1> // -1 means no compile-time limit (runtime check)
|
||||
bool CachePtrInt32OverflowPossible_ = false>
|
||||
struct unified_attention_decode_bs32_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
@@ -344,7 +342,6 @@ struct unified_attention_decode_bs32_kernel_traits
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
@@ -377,7 +374,7 @@ struct unified_attention_decode_bs32_kernel_traits
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits,
|
||||
MAX_NUM_BLOCKS>;
|
||||
CachePtrInt32OverflowPossible_>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
|
||||
@@ -66,7 +66,14 @@ struct UnifiedAttentionPipeline
|
||||
static constexpr ck_tile::index_t kPageBlockSize = UnifiedAttentionShape::kPageBlockSize;
|
||||
static constexpr ck_tile::index_t kHeadDim = UnifiedAttentionShape::kHeadDim;
|
||||
static constexpr ck_tile::index_t kHeadDimPadded = UnifiedAttentionShape::kHeadDimPadded;
|
||||
static constexpr ck_tile::index_t kMaxNumBlocks = Problem::kMaxNumBlocks;
|
||||
|
||||
// Overflow checking flag from Problem
|
||||
static constexpr bool kCachePtrInt32OverflowPossible = Problem::kCachePtrInt32OverflowPossible;
|
||||
// Set to true for large cache kernels (enables overflow check in loop)
|
||||
// Set to false for small cache kernels (compile-time eliminates check)
|
||||
|
||||
// Int32 overflow threshold for set_window_origin
|
||||
static constexpr long_index_t kInt32Max = 2147483647;
|
||||
|
||||
static_assert(kHeadDimPadded <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -364,36 +371,7 @@ struct UnifiedAttentionPipeline
|
||||
block_table_offset += num_blocks_start;
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
|
||||
// Use pointer rebasing to avoid int32 overflow in tensor_coordinate::get_offset()
|
||||
// Overflow happens when: row_index * stride > INT32_MAX
|
||||
// Example for d64/GQA-8: max_row=4,799,968, stride=512, offset=2,457,583,616 > INT32_MAX
|
||||
//
|
||||
// Calculate overflow threshold using compile-time constants where possible
|
||||
// Assumption: kv_page_size_in_blocks is typically 1 (page_size == kPageBlockSize)
|
||||
// For configurations where this isn't true, we use runtime PageSize
|
||||
//
|
||||
// Compile-time threshold calculation (assuming page_size_in_blocks == 1):
|
||||
// threshold = INT32_MAX / (kPageBlockSize * kHeadDim)
|
||||
// For d64, block_size=32: threshold = 2147483647 / (32 * 64) = 1,048,575 blocks
|
||||
//
|
||||
// Only enabled when:
|
||||
// 1. Row strides provided from kernel (indicates we have stride info) - runtime
|
||||
// 2. Cache size exceeds overflow threshold - compile-time if kMaxNumBlocks != -1
|
||||
// 3. hdim <= 64 - compile-time (hdim=128 has different buffer layout)
|
||||
constexpr long_index_t kOverflowThresholdBlocks =
|
||||
(kHeadDim <= 64) ? (2147483647L / (kPageBlockSize * kHeadDim)) : 2147483647L;
|
||||
|
||||
// Compile-time overflow detection when kMaxNumBlocks is specified
|
||||
constexpr bool kNeedsRebasing = (kMaxNumBlocks != -1) && (kHeadDim <= 64) &&
|
||||
(static_cast<long_index_t>(kMaxNumBlocks) > kOverflowThresholdBlocks);
|
||||
|
||||
const bool need_overflow_check = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
|
||||
const bool use_ptr_rebase = kNeedsRebasing ||
|
||||
(need_overflow_check && (kMaxNumBlocks == -1) &&
|
||||
(static_cast<long_index_t>(num_blocks) > kOverflowThresholdBlocks));
|
||||
|
||||
// Fast path: Create windows directly for small caches (no overflow risk)
|
||||
// Slow path: Use rebased pointers for large caches (overflow risk)
|
||||
// Create K/V DRAM windows
|
||||
auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_blk_idx_initial * PageSize, 0},
|
||||
@@ -406,44 +384,6 @@ struct UnifiedAttentionPipeline
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dram_window.init_raw();
|
||||
|
||||
// Variables for rebasing (only used if rebasing is possible)
|
||||
// When kMaxNumBlocks != -1 and kNeedsRebasing == false, compiler will eliminate this entirely
|
||||
using KPtrType = remove_cvref_t<decltype(k_dram_window.bottom_tensor_view_.buf_.p_data_)>;
|
||||
using VPtrType = remove_cvref_t<decltype(v_dram_window.bottom_tensor_view_.buf_.p_data_)>;
|
||||
[[maybe_unused]] KPtrType k_base_ptr = nullptr;
|
||||
[[maybe_unused]] VPtrType v_base_ptr = nullptr;
|
||||
[[maybe_unused]] long_index_t k_buf_size_orig = 0;
|
||||
[[maybe_unused]] long_index_t v_buf_size_orig = 0;
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Save original pointers and sizes for lazy rebasing
|
||||
k_base_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_;
|
||||
v_base_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_;
|
||||
k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
|
||||
// Initial rebase to first block
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
k_dram_window.bottom_tensor_view_.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_buf_size_orig - k_off;
|
||||
k_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
k_dram_window.init_raw();
|
||||
k_dram_window.set_window_origin({0, 0});
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
v_dram_window.bottom_tensor_view_.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_buf_size_orig - v_off;
|
||||
v_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
v_dram_window.init_raw();
|
||||
v_dram_window.set_window_origin({0, 0});
|
||||
}
|
||||
}
|
||||
|
||||
// prefetch K tile
|
||||
constexpr index_t k0_loops = 1;
|
||||
constexpr index_t k1_loops = 1;
|
||||
@@ -545,20 +485,6 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
// Lazy rebasing: track which block we're currently rebased to
|
||||
// Only call rebase_window (expensive init_raw) when we drift too far from base
|
||||
// Threshold: rebase when offset from base would exceed 1 billion (half of int32_max)
|
||||
// For d64, block_size=32: threshold = 1B / (32 * 64) = ~488,281 blocks
|
||||
// This is compile-time constant, allowing compiler to optimize
|
||||
constexpr long_index_t kRebaseThreshold = 1000000000L / (kPageBlockSize * kHeadDim);
|
||||
[[maybe_unused]] index_t k_base_block = 0;
|
||||
[[maybe_unused]] index_t v_base_block = 0;
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
k_base_block = kv_blk_idx_initial;
|
||||
v_base_block = kv_blk_idx_initial;
|
||||
}
|
||||
}
|
||||
|
||||
// Page block index tracking
|
||||
// const index_t kv_page_size_in_blocks =
|
||||
@@ -573,41 +499,27 @@ struct UnifiedAttentionPipeline
|
||||
index_t k_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Lazy rebasing: only call expensive rebase_window when drifting too far from base
|
||||
long_index_t offset_from_base = static_cast<long_index_t>(k_page_blk_idx) - k_base_block;
|
||||
if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value
|
||||
// Calculate offset for this block
|
||||
index_t offset = k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
|
||||
if(offset_from_base > kRebaseThreshold)
|
||||
{
|
||||
// Too far from base, rebase to current block (expensive: calls init_raw)
|
||||
k_base_block = k_page_blk_idx;
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Close to base, just update window origin (cheap: no init_raw)
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
long_index_t base_row = static_cast<long_index_t>(k_base_block) * PageSize;
|
||||
k_dram_window.set_window_origin({static_cast<index_t>(k_row - base_row), 0});
|
||||
}
|
||||
// For large cache, check if we'd overflow int32 in set_window_origin
|
||||
if constexpr(kCachePtrInt32OverflowPossible)
|
||||
{
|
||||
if(offset > kInt32Max)
|
||||
{
|
||||
// Rebase: advance pointer by offset, then use origin {0, 0}
|
||||
auto& buf = k_dram_window.bottom_tensor_view_.buf_;
|
||||
auto stride_0 = k_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
|
||||
buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
|
||||
k_dram_window.init_raw();
|
||||
k_dram_window.set_window_origin({0, 0});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path when rebasing not needed (kMaxNumBlocks is small)
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
// Fast path: no overflow, just set window origin
|
||||
k_dram_window.set_window_origin({offset, 0});
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
@@ -617,41 +529,27 @@ struct UnifiedAttentionPipeline
|
||||
index_t v_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
|
||||
|
||||
if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1))
|
||||
{
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
// Lazy rebasing: only call expensive rebase_window when drifting too far from base
|
||||
long_index_t offset_from_base = static_cast<long_index_t>(v_page_blk_idx) - v_base_block;
|
||||
if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value
|
||||
// Calculate offset for this block
|
||||
index_t offset = v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
|
||||
if(offset_from_base > kRebaseThreshold)
|
||||
{
|
||||
// Too far from base, rebase to current block (expensive: calls init_raw)
|
||||
v_base_block = v_page_blk_idx;
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Close to base, just update window origin (cheap: no init_raw)
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
long_index_t base_row = static_cast<long_index_t>(v_base_block) * PageSize;
|
||||
v_dram_window.set_window_origin({static_cast<index_t>(v_row - base_row), 0});
|
||||
}
|
||||
// For large cache, check if we'd overflow int32 in set_window_origin
|
||||
if constexpr(kCachePtrInt32OverflowPossible)
|
||||
{
|
||||
if(offset > kInt32Max)
|
||||
{
|
||||
// Rebase: advance pointer by offset, then use origin {0, 0}
|
||||
auto& buf = v_dram_window.bottom_tensor_view_.buf_;
|
||||
auto stride_0 = v_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0));
|
||||
buf.p_data_ = buf.p_data_ + (static_cast<long_index_t>(offset) * stride_0);
|
||||
v_dram_window.init_raw();
|
||||
v_dram_window.set_window_origin({0, 0});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path when rebasing not needed (kMaxNumBlocks is small)
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
// Fast path: no overflow, just set window origin
|
||||
v_dram_window.set_window_origin({offset, 0});
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
|
||||
@@ -20,7 +20,7 @@ template <typename QDataType_,
|
||||
typename UnifiedAttentionShape_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_,
|
||||
index_t MaxNumBlocks_ = -1>
|
||||
bool CachePtrInt32OverflowPossible_ = false> // TODO: Default false = no overflow expected
|
||||
struct UnifiedAttentionPipelineProblem
|
||||
{
|
||||
// TODO kM0 and KN1??
|
||||
@@ -42,11 +42,13 @@ struct UnifiedAttentionPipelineProblem
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
|
||||
static constexpr index_t kMaxNumBlocks = MaxNumBlocks_;
|
||||
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();
|
||||
|
||||
// TODO: Overflow check flag - controls whether to check for int32 overflow in loop
|
||||
static constexpr bool kCachePtrInt32OverflowPossible = CachePtrInt32OverflowPossible_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
|
||||
|
||||
Reference in New Issue
Block a user