diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp index 112efe1222..b4e0adf619 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_m_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_m_small_cache.cpp index c414497bd0..c62f37e186 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_m_small_cache.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_m_small_cache.cpp @@ -6,9 +6,9 @@ namespace ck_tile { -// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead) +// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp index ef17fc1971..c52fc16642 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s_small_cache.cpp index e5fd136271..e4c390fdcb 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s_small_cache.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_s_small_cache.cpp @@ -6,9 +6,9 @@ namespace ck_tile { -// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead) +// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp index 2c6531c835..92b4e4897a 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_decode_t.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_tiny_kernel_traits; + unified_attention_decode_tiny_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp index 204319568f..5af36b1142 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_bs32_kernel_traits; + unified_attention_decode_bs32_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_small_cache.cpp new file mode 100644 index 0000000000..317b3b478c --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_gqa8_bs32_narrow_small_cache.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_bs32_kernel_traits; // Small cache: overflow checks enabled + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp index f0c3617a52..1353e38380 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_m_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_m_small_cache.cpp index 77c25e050a..cce92502db 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_m_small_cache.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_m_small_cache.cpp @@ -6,9 +6,9 @@ namespace ck_tile { -// Medium-tier small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead) +// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp index ead32cf0bf..9b93635c51 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s_small_cache.cpp index 5847c20378..c4d35d99ec 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s_small_cache.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_s_small_cache.cpp @@ -6,9 +6,9 @@ namespace ck_tile { -// Small-cache optimized variant: MaxNumBlocks=100000 (zero rebasing overhead) +// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks) using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp index cc77cf7726..43170ad37b 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_decode_t.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_tiny_kernel_traits; + unified_attention_decode_tiny_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp index 27ccae7b06..e1702c7c75 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_bs32_kernel_traits; + unified_attention_decode_bs32_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow_small_cache.cpp new file mode 100644 index 0000000000..5ac1b55d1f --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_gqa8_bs32_narrow_small_cache.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_bs32_kernel_traits; // Small cache: overflow checks enabled + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp index b79fe8eeb2..7859f67129 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_m_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_m_small_cache.cpp new file mode 100644 index 0000000000..42860869f7 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_m_small_cache.cpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) +using kernel_traits = + unified_attention_decode_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp index 272439ecb0..21804b7585 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s_small_cache.cpp new file mode 100644 index 0000000000..3164397c00 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_s_small_cache.cpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +// Small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp index 1420e3fa40..530787487a 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_decode_t.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_tiny_kernel_traits; + unified_attention_decode_tiny_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp index 22d1f71e5b..c2cb747457 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_bs32_kernel_traits; + unified_attention_decode_bs32_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_small_cache.cpp new file mode 100644 index 0000000000..a4f6775971 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_gqa8_bs32_narrow_small_cache.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_bs32_kernel_traits; // Small cache: overflow checks enabled + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp index c883749ac2..cb43ed76fe 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_kernel_traits; + unified_attention_decode_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_m_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_m_small_cache.cpp new file mode 100644 index 0000000000..00f6ad8be5 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_m_small_cache.cpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +// Medium-tier small-cache optimized variant: MaxNumBlocks=false (zero rebasing overhead) +using kernel_traits = + unified_attention_decode_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp index b76f03fe0c..c42c20c844 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_small_kernel_traits; + unified_attention_decode_small_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s_small_cache.cpp new file mode 100644 index 0000000000..2a78cbc2f7 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_s_small_cache.cpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +// Small-cache optimized variant: CachePtrInt32OverflowPossible=false (no overflow checks) +using kernel_traits = + unified_attention_decode_small_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp index 134ab386b5..be4c86fb01 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_decode_t.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_tiny_kernel_traits; + unified_attention_decode_tiny_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp index 47a8dd7939..814abcef05 100644 --- a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_decode_bs32_kernel_traits; + unified_attention_decode_bs32_kernel_traits; // Large cache: overflow checks enabled INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow_small_cache.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow_small_cache.cpp new file mode 100644 index 0000000000..12a21a1bda --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_gqa8_bs32_narrow_small_cache.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_decode_bs32_kernel_traits; // Small cache: overflow checks enabled + +INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index f0c0bbee1a..f673cc2fca 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -64,24 +64,43 @@ std::ostream& operator<<(std::ostream& stream, return unified_attention_kernel_dispatch_decode(args, config); \ } -// Small-cache variants (7th template arg = MaxNumBlocks for compile-time overflow elimination). -// For d64/GQA-8/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks. -// Set MaxNumBlocks = 100,000 (conservative, safe for ~98K blocks) to guarantee no overflow. +// Small-cache variants (7th template arg = CachePtrInt32OverflowPossible=false). +// For small caches (<100K blocks), we can guarantee no int32 overflow, so compile-time eliminate overflow checks. #define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \ { \ - using kernel_traits = unified_attention_decode_kernel_traits; \ + using kernel_traits = unified_attention_decode_kernel_traits; \ return unified_attention_kernel_dispatch(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \ { \ - using kernel_traits = unified_attention_decode_small_kernel_traits; \ + using kernel_traits = unified_attention_decode_small_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(DType, IsMask, HSize, BM, NQPKV) \ { \ - using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ + using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ + return unified_attention_kernel_dispatch_decode(args, config); \ + } + +// Large-cache variants (7th template arg = CachePtrInt32OverflowPossible=true). +// For large caches (>=100K blocks), enable runtime overflow checking with pointer rebasing. +#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \ + { \ + using kernel_traits = unified_attention_decode_kernel_traits; \ + return unified_attention_kernel_dispatch(args, config); \ + } + +#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \ + { \ + using kernel_traits = unified_attention_decode_small_kernel_traits; \ + return unified_attention_kernel_dispatch_decode(args, config); \ + } + +#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(DType, IsMask, HSize, BM, NQPKV) \ + { \ + using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } @@ -108,26 +127,14 @@ static tile_tier select_tile_tier(const unified_attention_args& args) return tile_tier::medium; } -// Select between small-cache (compile-time overflow elimination) and large-cache variants. -// For d64/bs32: overflow threshold = 2^31 / (32 * 64) = 1,048,575 blocks -// We use 100,000 as the small-cache limit (conservative, safe for ~98K blocks) -static bool use_small_cache_variant(const unified_attention_args& args) -{ - // Only optimize for d64 with block_size < 64 (bs32 variants) - if(args.hdim != 64 || args.page_blk_size >= 64) - return false; - - // Conservative threshold: 100,000 blocks (~98K) - // This guarantees no int32 overflow for d64/bs32 - constexpr index_t kSmallCacheThreshold = 100000; - return args.num_blks <= kSmallCacheThreshold; -} std::pair unified_attention(const unified_attention_args& args, const stream_config& config) { const bool is_mask = (args.mask_type != static_cast(mask_enum::no_mask)); const auto tier = select_tile_tier(args); + // Python calculates overflow possibility and passes it directly + const bool use_small_cache = !args.cache_ptr_int32_overflow_possible; // d128, MHA (num_queries_per_kv == 1) if(args.hdim == 128 && args.num_queries_per_kv == 1) @@ -148,7 +155,6 @@ std::pair unified_attention(const unified_attention_args& args, if(args.hdim == 64 && args.num_queries_per_kv == 8) { const bool use_bs32 = (args.page_blk_size < 64); - const bool use_small_cache = use_small_cache_variant(args); if(tier == tile_tier::tiny) { @@ -157,13 +163,23 @@ std::pair unified_attention(const unified_attention_args& args, // Avoids 1-warp race condition; 2x less waste than small tier. if(args.data_type == unified_attention_args::data_type_enum::fp16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8) + if(use_small_cache) { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8) + } else { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8) + } } else if(args.data_type == unified_attention_args::data_type_enum::bf16) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8) + if(use_small_cache) { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8) + } else { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8) + } } } else { // bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2. @@ -184,8 +200,13 @@ std::pair unified_attention(const unified_attention_args& args, if(args.data_type == unified_attention_args::data_type_enum::fp16) { if(use_bs32) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8) + if(use_small_cache) { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8) + } else { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8) + } } else { if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8) else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8) @@ -198,8 +219,8 @@ std::pair unified_attention(const unified_attention_args& args, if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8) else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8) } else { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8) + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8) } } else { if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8) @@ -212,8 +233,13 @@ std::pair unified_attention(const unified_attention_args& args, if(args.data_type == unified_attention_args::data_type_enum::fp16) { if(use_bs32) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8) + if(use_small_cache) { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8) + } else { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8) + } } else { if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8) else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8) @@ -222,8 +248,13 @@ std::pair unified_attention(const unified_attention_args& args, else if(args.data_type == unified_attention_args::data_type_enum::bf16) { if(use_bs32) { - if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8) - else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8) + if(use_small_cache) { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_SMALL_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8) + } else { + if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8) + else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LARGE_CACHE(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8) + } } else { if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8) else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8) diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 8b645387a4..4bacd847dd 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -67,6 +67,8 @@ struct unified_attention_args index_t num_seqs; // number of batches for q index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown) + + bool cache_ptr_int32_overflow_possible = false; // true = use large cache variant with overflow checks }; std::ostream& operator<<(std::ostream& stream, diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 6b9109b4cf..5f701a362c 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -67,7 +67,8 @@ template + index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32, + bool CachePtrInt32OverflowPossible_ = false> struct unified_attention_kernel_traits { static constexpr auto date_type = DataType; @@ -116,7 +117,7 @@ struct unified_attention_kernel_traits unified_attention_shape, unified_attention_mask, unified_attention_traits, - -1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles + CachePtrInt32OverflowPossible_>; using unified_attention_pipeline = UnifiedAttentionPipeline; @@ -139,7 +140,7 @@ template // -1 means no compile-time limit (runtime check) + bool CachePtrInt32OverflowPossible_ = false> // Default false = no overflow expected struct unified_attention_decode_kernel_traits { static constexpr auto date_type = DataType; @@ -148,7 +149,6 @@ struct unified_attention_decode_kernel_traits static constexpr index_t kBlockM = BlockM_; static constexpr index_t HEAD_SIZE = HeadSize_; static constexpr index_t BLOCK_SIZE = BlockSize_; - static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; @@ -183,7 +183,7 @@ struct unified_attention_decode_kernel_traits unified_attention_shape, unified_attention_mask, unified_attention_traits, - -1>; // MaxNumBlocks = -1 (runtime check) for prefill/large tiles + CachePtrInt32OverflowPossible_>; using unified_attention_pipeline = UnifiedAttentionPipeline; @@ -203,7 +203,7 @@ template + bool CachePtrInt32OverflowPossible_ = false> struct unified_attention_decode_small_kernel_traits { static constexpr auto date_type = DataType; @@ -212,7 +212,6 @@ struct unified_attention_decode_small_kernel_traits static constexpr index_t kBlockM = BlockM_; static constexpr index_t HEAD_SIZE = HeadSize_; static constexpr index_t BLOCK_SIZE = BlockSize_; - static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; @@ -246,7 +245,7 @@ struct unified_attention_decode_small_kernel_traits unified_attention_shape, unified_attention_mask, unified_attention_traits, - MAX_NUM_BLOCKS>; + CachePtrInt32OverflowPossible_>; using unified_attention_pipeline = UnifiedAttentionPipeline // -1 means no compile-time limit (runtime check) + bool CachePtrInt32OverflowPossible_ = false> struct unified_attention_decode_tiny_kernel_traits { static constexpr auto date_type = DataType; @@ -278,7 +277,6 @@ struct unified_attention_decode_tiny_kernel_traits static constexpr index_t kBlockM = BlockM_; static constexpr index_t HEAD_SIZE = HeadSize_; static constexpr index_t BLOCK_SIZE = BlockSize_; - static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; @@ -312,7 +310,7 @@ struct unified_attention_decode_tiny_kernel_traits unified_attention_shape, unified_attention_mask, unified_attention_traits, - MAX_NUM_BLOCKS>; + CachePtrInt32OverflowPossible_>; using unified_attention_pipeline = UnifiedAttentionPipeline // -1 means no compile-time limit (runtime check) + bool CachePtrInt32OverflowPossible_ = false> struct unified_attention_decode_bs32_kernel_traits { static constexpr auto date_type = DataType; @@ -344,7 +342,6 @@ struct unified_attention_decode_bs32_kernel_traits static constexpr index_t kBlockM = BlockM_; static constexpr index_t HEAD_SIZE = HeadSize_; static constexpr index_t BLOCK_SIZE = BlockSize_; - static constexpr index_t MAX_NUM_BLOCKS = MaxNumBlocks_; static constexpr index_t num_queries_per_kv = NumQPerKV_; static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; @@ -377,7 +374,7 @@ struct unified_attention_decode_bs32_kernel_traits unified_attention_shape, unified_attention_mask, unified_attention_traits, - MAX_NUM_BLOCKS>; + CachePtrInt32OverflowPossible_>; using unified_attention_pipeline = UnifiedAttentionPipeline INT32_MAX - // Example for d64/GQA-8: max_row=4,799,968, stride=512, offset=2,457,583,616 > INT32_MAX - // - // Calculate overflow threshold using compile-time constants where possible - // Assumption: kv_page_size_in_blocks is typically 1 (page_size == kPageBlockSize) - // For configurations where this isn't true, we use runtime PageSize - // - // Compile-time threshold calculation (assuming page_size_in_blocks == 1): - // threshold = INT32_MAX / (kPageBlockSize * kHeadDim) - // For d64, block_size=32: threshold = 2147483647 / (32 * 64) = 1,048,575 blocks - // - // Only enabled when: - // 1. Row strides provided from kernel (indicates we have stride info) - runtime - // 2. Cache size exceeds overflow threshold - compile-time if kMaxNumBlocks != -1 - // 3. hdim <= 64 - compile-time (hdim=128 has different buffer layout) - constexpr long_index_t kOverflowThresholdBlocks = - (kHeadDim <= 64) ? (2147483647L / (kPageBlockSize * kHeadDim)) : 2147483647L; - - // Compile-time overflow detection when kMaxNumBlocks is specified - constexpr bool kNeedsRebasing = (kMaxNumBlocks != -1) && (kHeadDim <= 64) && - (static_cast(kMaxNumBlocks) > kOverflowThresholdBlocks); - - const bool need_overflow_check = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64); - const bool use_ptr_rebase = kNeedsRebasing || - (need_overflow_check && (kMaxNumBlocks == -1) && - (static_cast(num_blocks) > kOverflowThresholdBlocks)); - - // Fast path: Create windows directly for small caches (no overflow risk) - // Slow path: Use rebased pointers for large caches (overflow risk) + // Create K/V DRAM windows auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), {kv_blk_idx_initial * PageSize, 0}, @@ -406,44 +384,6 @@ struct UnifiedAttentionPipeline Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); - // Variables for rebasing (only used if rebasing is possible) - // When kMaxNumBlocks != -1 and kNeedsRebasing == false, compiler will eliminate this entirely - using KPtrType = remove_cvref_t; - using VPtrType = remove_cvref_t; - [[maybe_unused]] KPtrType k_base_ptr = nullptr; - [[maybe_unused]] VPtrType v_base_ptr = nullptr; - [[maybe_unused]] long_index_t k_buf_size_orig = 0; - [[maybe_unused]] long_index_t v_buf_size_orig = 0; - - if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1)) - { - if(use_ptr_rebase) - { - // Save original pointers and sizes for lazy rebasing - k_base_ptr = k_dram_window.bottom_tensor_view_.buf_.p_data_; - v_base_ptr = v_dram_window.bottom_tensor_view_.buf_.p_data_; - k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_; - v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_; - - // Initial rebase to first block - long_index_t k_off = - static_cast(kv_blk_idx_initial) * PageSize * k_row_stride; - k_dram_window.bottom_tensor_view_.buf_.p_data_ = k_base_ptr + k_off; - auto new_k = k_buf_size_orig - k_off; - k_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim; - k_dram_window.init_raw(); - k_dram_window.set_window_origin({0, 0}); - - long_index_t v_off = - static_cast(kv_blk_idx_initial) * PageSize * v_row_stride; - v_dram_window.bottom_tensor_view_.buf_.p_data_ = v_base_ptr + v_off; - auto new_v = v_buf_size_orig - v_off; - v_dram_window.bottom_tensor_view_.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim; - v_dram_window.init_raw(); - v_dram_window.set_window_origin({0, 0}); - } - } - // prefetch K tile constexpr index_t k0_loops = 1; constexpr index_t k1_loops = 1; @@ -545,20 +485,6 @@ struct UnifiedAttentionPipeline // Lazy rebasing: track which block we're currently rebased to // Only call rebase_window (expensive init_raw) when we drift too far from base - // Threshold: rebase when offset from base would exceed 1 billion (half of int32_max) - // For d64, block_size=32: threshold = 1B / (32 * 64) = ~488,281 blocks - // This is compile-time constant, allowing compiler to optimize - constexpr long_index_t kRebaseThreshold = 1000000000L / (kPageBlockSize * kHeadDim); - [[maybe_unused]] index_t k_base_block = 0; - [[maybe_unused]] index_t v_base_block = 0; - if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1)) - { - if(use_ptr_rebase) - { - k_base_block = kv_blk_idx_initial; - v_base_block = kv_blk_idx_initial; - } - } // Page block index tracking // const index_t kv_page_size_in_blocks = @@ -573,41 +499,27 @@ struct UnifiedAttentionPipeline index_t k_page_blk_idx = block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; - if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1)) - { - if(use_ptr_rebase) - { - // Lazy rebasing: only call expensive rebase_window when drifting too far from base - long_index_t offset_from_base = static_cast(k_page_blk_idx) - k_base_block; - if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value + // Calculate offset for this block + index_t offset = k_page_blk_idx * PageSize + + (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - if(offset_from_base > kRebaseThreshold) - { - // Too far from base, rebase to current block (expensive: calls init_raw) - k_base_block = k_page_blk_idx; - long_index_t k_row = - static_cast(k_page_blk_idx) * PageSize + - (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig); - } - else - { - // Close to base, just update window origin (cheap: no init_raw) - long_index_t k_row = - static_cast(k_page_blk_idx) * PageSize + - (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - long_index_t base_row = static_cast(k_base_block) * PageSize; - k_dram_window.set_window_origin({static_cast(k_row - base_row), 0}); - } + // For large cache, check if we'd overflow int32 in set_window_origin + if constexpr(kCachePtrInt32OverflowPossible) + { + if(offset > kInt32Max) + { + // Rebase: advance pointer by offset, then use origin {0, 0} + auto& buf = k_dram_window.bottom_tensor_view_.buf_; + auto stride_0 = k_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0)); + buf.p_data_ = buf.p_data_ + (static_cast(offset) * stride_0); + k_dram_window.init_raw(); + k_dram_window.set_window_origin({0, 0}); return; } } - // Fast path when rebasing not needed (kMaxNumBlocks is small) - k_dram_window.set_window_origin( - {k_page_blk_idx * PageSize + - (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize, - 0}); + // Fast path: no overflow, just set window origin + k_dram_window.set_window_origin({offset, 0}); }; auto V_mem_load = [&](auto v_lds_write_idx) { @@ -617,41 +529,27 @@ struct UnifiedAttentionPipeline index_t v_page_blk_idx = block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; - if constexpr(kNeedsRebasing || (kMaxNumBlocks == -1)) - { - if(use_ptr_rebase) - { - // Lazy rebasing: only call expensive rebase_window when drifting too far from base - long_index_t offset_from_base = static_cast(v_page_blk_idx) - v_base_block; - if(offset_from_base < 0) offset_from_base = -offset_from_base; // abs value + // Calculate offset for this block + index_t offset = v_page_blk_idx * PageSize + + (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - if(offset_from_base > kRebaseThreshold) - { - // Too far from base, rebase to current block (expensive: calls init_raw) - v_base_block = v_page_blk_idx; - long_index_t v_row = - static_cast(v_page_blk_idx) * PageSize + - (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig); - } - else - { - // Close to base, just update window origin (cheap: no init_raw) - long_index_t v_row = - static_cast(v_page_blk_idx) * PageSize + - (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize; - long_index_t base_row = static_cast(v_base_block) * PageSize; - v_dram_window.set_window_origin({static_cast(v_row - base_row), 0}); - } + // For large cache, check if we'd overflow int32 in set_window_origin + if constexpr(kCachePtrInt32OverflowPossible) + { + if(offset > kInt32Max) + { + // Rebase: advance pointer by offset, then use origin {0, 0} + auto& buf = v_dram_window.bottom_tensor_view_.buf_; + auto stride_0 = v_dram_window.bottom_tensor_view_.desc_.calculate_offset(make_tuple(1, 0)); + buf.p_data_ = buf.p_data_ + (static_cast(offset) * stride_0); + v_dram_window.init_raw(); + v_dram_window.set_window_origin({0, 0}); return; } } - // Fast path when rebasing not needed (kMaxNumBlocks is small) - v_dram_window.set_window_origin( - {v_page_blk_idx * PageSize + - (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize, - 0}); + // Fast path: no overflow, just set window origin + v_dram_window.set_window_origin({offset, 0}); }; auto K_lds_load = [&](auto k_lds_read_idx) { diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 7f2b7a5f5c..aa653384e6 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -20,7 +20,7 @@ template + bool CachePtrInt32OverflowPossible_ = false> // TODO: Default false = no overflow expected struct UnifiedAttentionPipelineProblem { // TODO kM0 and KN1?? @@ -42,11 +42,13 @@ struct UnifiedAttentionPipelineProblem using Traits = remove_cvref_t; using FmhaMask = remove_cvref_t; - static constexpr index_t kMaxNumBlocks = MaxNumBlocks_; static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps; static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps; static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size(); + // TODO: Overflow check flag - controls whether to check for int32 overflow in loop + static constexpr bool kCachePtrInt32OverflowPossible = CachePtrInt32OverflowPossible_; + // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadHeadDim = Traits::kPadHeadDim;