diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_nopage.cpp new file mode 100644 index 0000000000..c79db60e64 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_nopage.cpp new file mode 100644 index 0000000000..ccb127dec9 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_nopage.cpp new file mode 100644 index 0000000000..130f2e4560 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_nopage.cpp new file mode 100644 index 0000000000..dd122a3e7a --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_mask_nopage.cpp new file mode 100644 index 0000000000..7fbb1f8e3e --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_nmask_nopage.cpp new file mode 100644 index 0000000000..a34d77dd60 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp8_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_nopage.cpp new file mode 100644 index 0000000000..dff1796551 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_nopage.cpp new file mode 100644 index 0000000000..389c34590f --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_bf16_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_nopage.cpp new file mode 100644 index 0000000000..eef034ebae --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_nopage.cpp new file mode 100644 index 0000000000..3b60689578 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp16_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_mask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_mask_nopage.cpp new file mode 100644 index 0000000000..a02c5a0724 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_mask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, true) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_nmask_nopage.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_nmask_nopage.cpp new file mode 100644 index 0000000000..7e0539519a --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d64_fp8_nmask_nopage.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, false) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index c3ef2e5657..bbcaabb7ec 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -115,12 +115,13 @@ namespace { template + index_t PageSize, + bool EnablePaging = true> std::pair dispatch_one(const unified_attention_args& args, const stream_config& config) { return unified_attention_kernel_dispatch< - unified_attention_kernel_traits>(args, config); + unified_attention_kernel_traits>(args, config); } // --------------------------------------------------------------------------- @@ -132,7 +133,13 @@ std::pair dispatch_one(const unified_attention_args& args, // to the matching one; anything else (or any non-prefill variant) goes to // the PageSize=0 catch-all that keeps the legacy runtime-page-size code. // -// The fast-path payoff is two-fold: +// When args.kv_contiguous is set, we instead route to the contiguous-K/V +// (THD-layout) instance — a third compile-time variant per (prefill, +// dtype, mask) triple that skips the block_tables fetch entirely. Decode +// variants don't have a contiguous instance (callers with no KV cache +// don't have decode workloads); they fall back to the paged catch-all. +// +// The paged fast-path payoff is two-fold: // 1. Compile-time `page_size` lets the compiler strength-reduce every // `/ page_size`, `* page_size`, and `% page_size` inside the per-tile // address chain. div-by-32 collapses to `shr 5`, etc. @@ -150,6 +157,11 @@ std::pair dispatch_page_size(const unified_attention_args& args, if constexpr(V == KernelVariant::prefill_d128 || V == KernelVariant::prefill_d64) { + if(args.kv_contiguous) + { + // Contiguous (THD) instance — PageSize is irrelevant, EnablePaging=false. + return dispatch_one(args, config); + } switch(args.page_blk_size) { case 16: return dispatch_one(args, config); diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 9f99281ba5..85ff84a925 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -81,6 +81,26 @@ struct unified_attention_args index_t num_seqs; // number of batches for q index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown) + // Layout selector for K/V: + // false (default) : paged KV cache — K/V are + // [num_blks, page_blk_size, num_kv_heads, head_size] + // and `block_tables_ptr` resolves logical → physical + // page. Used by vLLM / SGLang inference servers. + // true : contiguous (THD) — K/V are + // [num_kv_tokens, num_kv_heads, head_size] for the + // current request and `block_tables_ptr` is ignored. + // The kernel skips the per-tile page-table fetch and + // the / % page_size arithmetic entirely. Used by + // pretraining / flash-attention-style callers that + // don't have a shared KV cache. When this is true, + // `page_blk_size` is ignored; treat `num_blks` as a + // virtual page count: the K/V tensor view still has + // shape (num_blks * page_blk_size, head_dim), so the + // caller can either set num_blks=num_kv_tokens with + // page_blk_size=1 or any equivalent factorisation + // that yields the right `num_kv_tokens` total. + bool kv_contiguous = false; + // Set to true when the K/V cache is large enough that an int32 byte // offset into it can overflow (i.e. when // num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index e8061fa528..918b6d5b07 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -114,11 +114,18 @@ struct unified_attention_problem_traits struct variant_config; -// Each variant_config exposes `Pipeline` so the traits -// can pin the page size at compile time. PageSize=0 means "runtime page -// size" (legacy behaviour); the host dispatcher selects a non-zero value -// when it can prove the runtime `args.page_blk_size` matches one of the -// instances we compiled. +// Each variant_config exposes `Pipeline` +// so the traits can pin the page size and the paged-vs-contiguous toggle +// at compile time: +// - PageSize=0 : runtime page size (legacy paged behaviour). +// - PageSize>0 : compile-time-pinned paged behaviour. +// - EnablePaging=true (default): K/V are a paged KV cache. +// - EnablePaging=false : K/V are contiguous (THD) tensors — +// refresh_*_offsets skips block_tables +// entirely, PageSize is ignored. +// The host dispatcher (dispatch_variant in unified_attention.cpp) picks +// the matching instance from `args.page_blk_size` / `args.kv_contiguous` +// at launch time. template <> struct variant_config { @@ -127,9 +134,11 @@ struct variant_config static constexpr index_t BlockSize = 32; using BlockWarps = sequence<8, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = false; }; @@ -141,9 +150,11 @@ struct variant_config static constexpr index_t BlockSize = 32; using BlockWarps = sequence<4, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = false; }; @@ -155,9 +166,11 @@ struct variant_config static constexpr index_t BlockSize = 32; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = true; }; @@ -169,9 +182,11 @@ struct variant_config static constexpr index_t BlockSize = 32; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<16, 16, 32>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = true; }; @@ -183,9 +198,11 @@ struct variant_config static constexpr index_t BlockSize = 64; using BlockWarps = sequence<8, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = false; }; @@ -197,9 +214,11 @@ struct variant_config static constexpr index_t BlockSize = 64; using BlockWarps = sequence<4, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = false; }; @@ -211,9 +230,11 @@ struct variant_config static constexpr index_t BlockSize = 64; using BlockWarps = sequence<2, 1, 1>; using WarpGemmShape = sequence<32, 32, 16>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = true; }; @@ -225,9 +246,11 @@ struct variant_config static constexpr index_t BlockSize = 64; using BlockWarps = sequence<1, 1, 1>; using WarpGemmShape = sequence<16, 16, 32>; - template - using Pipeline = - UnifiedAttentionPipeline; + template + using Pipeline = UnifiedAttentionPipeline; static constexpr bool kUseDecodeGrid = true; }; @@ -237,25 +260,35 @@ struct variant_config // Single templated trait. Pulls per-variant knobs from variant_config and // per-dtype element types from unified_attention_problem_traits. // ============================================================================= -// kPageSize: optional compile-time pin of the runtime `page_size`. Default -// 0 keeps the legacy runtime-page-size behaviour; a non-zero value lets the -// pipeline strength-reduce the per-tile arithmetic *and* widen the Tier 0 / -// Tier 2 gate from the conservative `KY0_step_N <= 16` hedge to the real -// `KY0_step_N <= kPageSize` condition. The host dispatcher (dispatch_variant -// in unified_attention.cpp) picks the matching instance at launch time. +// kPageSize : optional compile-time pin of the runtime `page_size`. +// Default 0 keeps the legacy runtime-page-size behaviour; +// a non-zero value lets the pipeline strength-reduce the +// per-tile arithmetic *and* widen the Tier 0 / Tier 2 gate +// from the conservative `KY0_step_N <= 16` hedge to the +// real `KY0_step_N <= kPageSize` condition. +// kEnablePaging : true (default) — paged KV cache (vLLM/SGLang). false — +// contiguous (THD) K/V tensors that skip the entire +// block_tables fetch chain. Mutually exclusive use cases; +// only one path is emitted per instance. PageSize is +// ignored when kEnablePaging == false. +// The host dispatcher (dispatch_variant in unified_attention.cpp) picks +// the matching instance from `args.page_blk_size` / `args.kv_contiguous` +// at launch time. template + ck_tile::index_t kPageSize_ = 0, + bool kEnablePaging_ = true> struct unified_attention_kernel_traits { using cfg = variant_config; using dt = unified_attention_problem_traits; - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; - static constexpr KernelVariant variant = V; - static constexpr ck_tile::index_t kPageSize = kPageSize_; + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; + static constexpr KernelVariant variant = V; + static constexpr ck_tile::index_t kPageSize = kPageSize_; + static constexpr bool kEnablePaging = kEnablePaging_; static constexpr index_t HEAD_SIZE = cfg::HeadSize; static constexpr index_t kBlockM = cfg::BlockM; @@ -301,7 +334,9 @@ struct unified_attention_kernel_traits unified_attention_traits>; using unified_attention_pipeline = - typename cfg::template Pipeline; + typename cfg::template Pipeline; using epilogue = Default2DEpilogue unified_attention_kernel_dispatch(const unified_attention } // namespace ck_tile -// One-line instantiation per (V, DataType, IsMasking, PageSize) combination. -// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0 -// is the legacy runtime-page-size instance (catch-all fallback). Non-zero -// values pin the runtime `page_size` argument to that literal — see the -// dispatch_variant switch in unified_attention.cpp for routing. -#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \ - template <> \ - std::pair unified_attention_kernel_dispatch< \ - unified_attention_kernel_traits>(const unified_attention_args& args, \ - const stream_config& config) \ - { \ - using Traits = unified_attention_kernel_traits< \ - KernelVariant::VARIANT_, \ - unified_attention_args::data_type_enum::DTYPE_, \ - IS_MASK_, \ - PAGE_SIZE_>; \ - return std::make_pair(true, \ - unified_attention_kernel_launch(args, config)); \ +// One-line instantiation per (V, DataType, IsMasking, PageSize, EnablePaging) +// combination. Each instance .cpp consists of exactly one of these calls. +// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = true : legacy runtime-page-size +// instance (catch-all paged +// fallback). +// - PAGE_SIZE_ > 0 + ENABLE_PAGING_ = true : constexpr-pinned paged. +// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = false : contiguous (THD) instance. +// See dispatch_variant in unified_attention.cpp for runtime routing. +#define INST_UNIFIED_ATTENTION_DISPATCH_PS_NP( \ + VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, ENABLE_PAGING_) \ + template <> \ + std::pair unified_attention_kernel_dispatch< \ + unified_attention_kernel_traits>(const unified_attention_args& args, \ + const stream_config& config) \ + { \ + using Traits = unified_attention_kernel_traits< \ + KernelVariant::VARIANT_, \ + unified_attention_args::data_type_enum::DTYPE_, \ + IS_MASK_, \ + PAGE_SIZE_, \ + ENABLE_PAGING_>; \ + return std::make_pair(true, \ + unified_attention_kernel_launch(args, config)); \ } +// Backward-compat shorthand: PageSize-only specialization on the paged path. +#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \ + INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, true) + // Backward-compat shorthand for the existing one-liners — the default -// `PageSize = 0` instance is the catch-all runtime-page-size kernel. -#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \ - INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0) +// PageSize = 0, paged, catch-all runtime-page-size kernel. +#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \ + INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, true) + +// Contiguous (THD) K/V instance. PageSize is forced to 0 (ignored anyway +// when EnablePaging = false). One instance per (variant, dtype, mask) +// triple. +#define INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(VARIANT_, DTYPE_, IS_MASK_) \ + INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, false) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b79f1190bd..3a6b112e9a 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -48,22 +48,38 @@ namespace ck_tile { -// kPageSize_ : non-type template parameter that pins the runtime -// `page_size` argument to a compile-time constant when > 0. The host -// dispatcher selects an instance whose kPageSize_ matches `args.page_blk_size` -// and routes execution there; instances compiled with kPageSize_ == 0 keep -// the legacy runtime-page-size path and serve as the catch-all fallback for -// uncommon page sizes. Having the value at compile time: -// 1. lets the compiler strength-reduce every `/ page_size`, `* page_size`, -// `% page_size` into shift / multiply-by-magic-constant on the literal -// (e.g. div-by-32 → shr 5); -// 2. lets the Tier 0 / Tier 2 gate use the real `KY0_step_N <= kPageSize` -// condition instead of the conservative `KY0_step_N <= 16` hedge, so -// prefill_d128 bf16, prefill_d64 bf16, and prefill_d64 fp8 also gain -// the scalar-promote + LDS-cache fast path on their natural page sizes. +// kPageSize_ : non-type template parameter that pins the runtime +// `page_size` argument to a compile-time constant when > 0. +// The host dispatcher selects an instance whose kPageSize_ +// matches `args.page_blk_size` and routes execution there; +// instances compiled with kPageSize_ == 0 keep the legacy +// runtime-page-size path and serve as the catch-all +// fallback for uncommon page sizes. Having the value at +// compile time: +// 1. lets the compiler strength-reduce every / * % +// page_size into shift / multiply-by-magic; +// 2. lets the Tier 0 / Tier 2 gate use the real +// `KY0_step_N <= kPageSize` condition instead of +// the conservative `<= 16` hedge. +// Only meaningful when kEnablePaging_ == true; ignored +// (and conventionally set to 0) when paging is disabled. +// kEnablePaging_ : true (default) is the paged-KV-cache layout used by +// vLLM/SGLang inference servers — K/V are stored as +// [num_blocks, page_size, num_heads, head_dim] and a +// `block_tables` index resolves logical → physical pages. +// false is the contiguous "THD" layout used by +// pretraining / flash-attention-style callers — +// K/V are stored as [num_kv_tokens, num_heads, head_dim] +// and `refresh_*_offsets` just multiplies the logical +// token by the row stride. The contiguous path skips +// the entire block_tables fetch, the per-tile / % +// page_size arithmetic, and the Tier 0 / Tier 2 +// LDS-cache machinery — `block_tables_ptr` and +// `page_size_runtime` are ignored in that mode. template + ck_tile::index_t kPageSize_ = 0, + bool kEnablePaging_ = true> struct UnifiedAttentionPipeline { using Problem = ck_tile::remove_cvref_t; @@ -72,6 +88,8 @@ struct UnifiedAttentionPipeline // Compile-time page size (0 = runtime). See class-level comment above. static constexpr ck_tile::index_t kPageSize = kPageSize_; static constexpr bool kHasCePageSize = (kPageSize_ > 0); + // Paged KV cache vs contiguous THD layout. See class-level comment above. + static constexpr bool kEnablePaging = kEnablePaging_; using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; @@ -147,6 +165,10 @@ struct UnifiedAttentionPipeline CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPageTableLdsBytes() { + // Contiguous (THD) layout doesn't go through block_tables at all, + // so the LDS cache is dead — skip the allocation entirely. + if constexpr(!kEnablePaging) return 0; + // Allocate the cache only for the kernel instances where Tier 0's // constexpr gate fires (otherwise the lambdas wouldn't read it and // the LDS would sit idle, hurting occupancy for nothing). Mirror the @@ -677,12 +699,16 @@ struct UnifiedAttentionPipeline // (newly ON — biggest win) // prefill_d64 bf16 KY0_step_N=64 @ ps=64 74.4 → 73.4 ms (-1.3%) // (newly ON; small win) + // Tier 0 + Tier 2 only make sense on the paged path — the + // contiguous (THD) layout has no block_tables to scalar-promote. constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16}; constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16}; constexpr bool kScalarPromoteKPageIdx = + kEnablePaging && (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap); constexpr bool kScalarPromoteVPageIdx = + kEnablePaging && (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap); @@ -742,7 +768,22 @@ struct UnifiedAttentionPipeline auto refresh_k_offsets = [&](index_t k_tile_idx) { static_for<0, KNRepeat, 1>{}([&](auto i) { - if constexpr(kScalarPromoteKPageIdx) + if constexpr(!kEnablePaging) + { + // Contiguous (THD) layout: K is one flat + // [num_kv_tokens, head_dim] tensor for the current + // kv-head, so the offset is just `token * row_stride`. + // No block_tables fetch, no / % page_size — the entire + // per-tile address-comp chain collapses to a single + // imad on a per-lane token index. Frees up the cycles + // Tier 0/2 were spending paying down the indirection. + const index_t logical_token = split_token_offset + + k_tile_idx * kPageBlockSize + k_thread_n_pos + + static_cast(i.value) * KY0_step_N; + k_page_offsets(i) = + static_cast(logical_token) * k_row_stride; + } + else if constexpr(kScalarPromoteKPageIdx) { // Compute the uniform per-`i` base in scalar; force the // resulting page-table index into an SGPR. Tier 2 reads @@ -787,7 +828,16 @@ struct UnifiedAttentionPipeline }; auto refresh_v_offsets = [&](index_t v_tile_idx) { static_for<0, VNRepeat, 1>{}([&](auto i) { - if constexpr(kScalarPromoteVPageIdx) + if constexpr(!kEnablePaging) + { + // Contiguous (THD) layout: see refresh_k_offsets above. + const index_t logical_token = split_token_offset + + v_tile_idx * kPageBlockSize + v_thread_n_pos + + static_cast(i.value) * VY0_step_N; + v_page_offsets(i) = + static_cast(logical_token) * v_row_stride; + } + else if constexpr(kScalarPromoteVPageIdx) { const index_t i_base_token = split_token_offset + v_tile_idx * kPageBlockSize +