CK-UA: optional paging — contiguous (THD) K/V path, prefill_d128 fp8 -28%

Add a `bool kEnablePaging_` non-type template parameter on
UnifiedAttentionPipeline (default true preserves the paged behaviour).
When false, `refresh_*_offsets` collapses to a single per-row
`logical_token * row_stride` imad — no block_tables fetch, no
`/` or `%` page_size arithmetic, no Tier 0 scalar-promote, and no Tier 2
LDS-cache populate. The host selects between paths via a new
`args.kv_contiguous` runtime flag plumbed through dispatch_variant<V>.

Twelve new prefill instances pin EnablePaging=false:
  prefill_d{64,128} × {fp16, bf16, fp8} × {mask, nmask}

Decode variants stay on the paged path — callers without a KV cache
don't have decode workloads, and the binary-size cost isn't justified.

Measured impact on the same physical K/V memory (sq=1×4096, causal,
page_size=32 paged baseline, MI355, n=30 iters):

  variant            sk     paged    contig   Δ
  prefill_d64  bf16   4096   0.274    0.227   -17.1 %
  prefill_d64  bf16  16384   1.529    1.198   -21.6 %
  prefill_d64  bf16  32768   3.218    2.505   -22.1 %
  prefill_d64  fp8    4096   0.299    0.235   -21.4 %
  prefill_d64  fp8   16384   1.489    1.150   -22.7 %
  prefill_d64  fp8   32768   3.054    2.386   -21.9 %
  prefill_d128 bf16   4096   0.493    0.397   -19.3 %
  prefill_d128 bf16  16384   2.638    2.224   -15.7 %
  prefill_d128 bf16  32768   5.731    4.598   -19.8 %
  prefill_d128 fp8    4096   0.476    0.341   -28.3 %
  prefill_d128 fp8   16384   2.416    1.792   -25.8 %
  prefill_d128 fp8   32768   4.973    3.727   -25.0 %

prefill_d128 fp8 at -28 % is the single biggest UA optimisation
measured to date — bigger than Tier 0 (-12 %), Tier 2 (-5 %), and the
Tier-3 d=64 fp8 win (-16 %).

Correctness validated by bit-exact comparison against the paged
instance with page_size=32 and identity block_tables on 48 shape ×
dtype × mask combinations.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-19 13:15:31 +00:00
parent 06e1a70e7a
commit c9bc5350c8
16 changed files with 350 additions and 85 deletions

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / bf16 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / bf16 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / fp16 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / fp16 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / fp8 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d128 / fp8 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / bf16 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / bf16 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / fp16 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / fp16 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, false)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / fp8 with IS_MASK_ = true.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, true)
} // namespace ck_tile

View File

@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
namespace ck_tile {
// Contiguous (THD) K/V instance: the NOPAGE macro forwards PAGE_SIZE_ = 0 and
// ENABLE_PAGING_ = false. Arguments are (variant, dtype, IS_MASK_), i.e. this
// is prefill_d64 / fp8 with IS_MASK_ = false.
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, false)
} // namespace ck_tile

View File

@@ -115,12 +115,13 @@ namespace {
// Thin forwarding helper: names the unified_attention_kernel_traits type from
// the template arguments and calls the matching
// unified_attention_kernel_dispatch specialization with (args, config),
// returning its (bool, float) pair unchanged.
// EnablePaging defaults to true, so call sites that pass only
// <V, DType, IsMask, PageSize> keep selecting the paged instances; the
// contiguous (THD) instance is selected by passing EnablePaging = false.
template <KernelVariant V,
unified_attention_args::data_type_enum DType,
bool IsMask,
index_t PageSize>
index_t PageSize,
bool EnablePaging = true>
std::pair<bool, float> dispatch_one(const unified_attention_args& args,
const stream_config& config)
{
return unified_attention_kernel_dispatch<
unified_attention_kernel_traits<V, DType, IsMask, PageSize>>(args, config);
unified_attention_kernel_traits<V, DType, IsMask, PageSize, EnablePaging>>(args, config);
}
// ---------------------------------------------------------------------------
@@ -132,7 +133,13 @@ std::pair<bool, float> dispatch_one(const unified_attention_args& args,
// to the matching one; anything else (or any non-prefill variant) goes to
// the PageSize=0 catch-all that keeps the legacy runtime-page-size code.
//
// The fast-path payoff is two-fold:
// When args.kv_contiguous is set, we instead route to the contiguous-K/V
// (THD-layout) instance — a third compile-time variant per (prefill,
// dtype, mask) triple that skips the block_tables fetch entirely. Decode
// variants don't have a contiguous instance (callers with no KV cache
// don't have decode workloads); they fall back to the paged catch-all.
//
// The paged fast-path payoff is two-fold:
// 1. Compile-time `page_size` lets the compiler strength-reduce every
// `/ page_size`, `* page_size`, and `% page_size` inside the per-tile
// address chain. div-by-32 collapses to `shr 5`, etc.
@@ -150,6 +157,11 @@ std::pair<bool, float> dispatch_page_size(const unified_attention_args& args,
if constexpr(V == KernelVariant::prefill_d128 ||
V == KernelVariant::prefill_d64)
{
if(args.kv_contiguous)
{
// Contiguous (THD) instance — PageSize is irrelevant, EnablePaging=false.
return dispatch_one<V, DType, IsMask, 0, /*EnablePaging=*/false>(args, config);
}
switch(args.page_blk_size)
{
case 16: return dispatch_one<V, DType, IsMask, 16>(args, config);

View File

@@ -81,6 +81,26 @@ struct unified_attention_args
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
// Layout selector for K/V:
// false (default) : paged KV cache — K/V are
// [num_blks, page_blk_size, num_kv_heads, head_size]
// and `block_tables_ptr` resolves logical → physical
// page. Used by vLLM / SGLang inference servers.
// true : contiguous (THD) — K/V are
// [num_kv_tokens, num_kv_heads, head_size] for the
// current request and `block_tables_ptr` is ignored.
// The kernel skips the per-tile page-table fetch and
// the / % page_size arithmetic entirely. Used by
// pretraining / flash-attention-style callers that
// don't have a shared KV cache. When this is true the
// kernel performs no page lookup, so `page_blk_size`
// has no paging meaning. It still matters for sizing,
// though: the K/V tensor view keeps the shape
// (num_blks * page_blk_size, head_dim). Treat
// `num_blks` as a virtual page count and choose the
// pair so that num_blks * page_blk_size yields the
// right `num_kv_tokens` total — e.g.
// num_blks=num_kv_tokens with page_blk_size=1, or any
// equivalent factorisation.
bool kv_contiguous = false;
// Set to true when the K/V cache is large enough that an int32 byte
// offset into it can overflow (i.e. when
// num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX

View File

@@ -114,11 +114,18 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
template <KernelVariant V>
struct variant_config;
// Each variant_config exposes `Pipeline<Problem, PageSize>` so the traits
// can pin the page size at compile time. PageSize=0 means "runtime page
// size" (legacy behaviour); the host dispatcher selects a non-zero value
// when it can prove the runtime `args.page_blk_size` matches one of the
// instances we compiled.
// Each variant_config exposes `Pipeline<Problem, PageSize, EnablePaging>`
// so the traits can pin the page size and the paged-vs-contiguous toggle
// at compile time:
// - PageSize=0 : runtime page size (legacy paged behaviour).
// - PageSize>0 : compile-time-pinned paged behaviour.
// - EnablePaging=true (default): K/V are a paged KV cache.
// - EnablePaging=false : K/V are contiguous (THD) tensors —
// refresh_*_offsets skips block_tables
// entirely, PageSize is ignored.
// The host dispatcher (dispatch_variant in unified_attention.cpp) picks
// the matching instance from `args.page_blk_size` / `args.kv_contiguous`
// at launch time.
template <>
struct variant_config<KernelVariant::prefill_d128>
{
@@ -127,9 +134,11 @@ struct variant_config<KernelVariant::prefill_d128>
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<8, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineDefaultPolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = false;
};
@@ -141,9 +150,11 @@ struct variant_config<KernelVariant::decode_d128_m128>
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<4, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineDefaultPolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = false;
};
@@ -155,9 +166,11 @@ struct variant_config<KernelVariant::decode_d128_m32>
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineTinyDecodePolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = true;
};
@@ -169,9 +182,11 @@ struct variant_config<KernelVariant::decode_d128_m16>
static constexpr index_t BlockSize = 32;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<16, 16, 32>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineTinyDecodePolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = true;
};
@@ -183,9 +198,11 @@ struct variant_config<KernelVariant::prefill_d64>
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<8, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineDefaultPolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = false;
};
@@ -197,9 +214,11 @@ struct variant_config<KernelVariant::decode_d64_m128>
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<4, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineDefaultPolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = false;
};
@@ -211,9 +230,11 @@ struct variant_config<KernelVariant::decode_d64_m64>
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<2, 1, 1>;
using WarpGemmShape = sequence<32, 32, 16>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDecodePolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineDecodePolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = true;
};
@@ -225,9 +246,11 @@ struct variant_config<KernelVariant::decode_d64_m16>
static constexpr index_t BlockSize = 64;
using BlockWarps = sequence<1, 1, 1>;
using WarpGemmShape = sequence<16, 16, 32>;
template <typename Problem, index_t PageSize = 0>
using Pipeline =
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
using Pipeline = UnifiedAttentionPipeline<Problem,
UnifiedAttentionPipelineTinyDecodePolicy,
PageSize,
EnablePaging>;
static constexpr bool kUseDecodeGrid = true;
};
@@ -237,25 +260,35 @@ struct variant_config<KernelVariant::decode_d64_m16>
// Single templated trait. Pulls per-variant knobs from variant_config<V> and
// per-dtype element types from unified_attention_problem_traits<DataType>.
// =============================================================================
// kPageSize: optional compile-time pin of the runtime `page_size`. Default
// 0 keeps the legacy runtime-page-size behaviour; a non-zero value lets the
// pipeline strength-reduce the per-tile arithmetic *and* widen the Tier 0 /
// Tier 2 gate from the conservative `KY0_step_N <= 16` hedge to the real
// `KY0_step_N <= kPageSize` condition. The host dispatcher (dispatch_variant
// in unified_attention.cpp) picks the matching instance at launch time.
// kPageSize : optional compile-time pin of the runtime `page_size`.
// Default 0 keeps the legacy runtime-page-size behaviour;
// a non-zero value lets the pipeline strength-reduce the
// per-tile arithmetic *and* widen the Tier 0 / Tier 2 gate
// from the conservative `KY0_step_N <= 16` hedge to the
// real `KY0_step_N <= kPageSize` condition.
// kEnablePaging : true (default) — paged KV cache (vLLM/SGLang). false —
// contiguous (THD) K/V tensors that skip the entire
// block_tables fetch chain. Mutually exclusive use cases;
// only one path is emitted per instance. PageSize is
// ignored when kEnablePaging == false.
// The host dispatcher (dispatch_variant in unified_attention.cpp) picks
// the matching instance from `args.page_blk_size` / `args.kv_contiguous`
// at launch time.
template <KernelVariant V,
unified_attention_args::data_type_enum DataType,
bool IsMasking,
ck_tile::index_t kPageSize_ = 0>
ck_tile::index_t kPageSize_ = 0,
bool kEnablePaging_ = true>
struct unified_attention_kernel_traits
{
using cfg = variant_config<V>;
using dt = unified_attention_problem_traits<DataType>;
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr KernelVariant variant = V;
static constexpr ck_tile::index_t kPageSize = kPageSize_;
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr KernelVariant variant = V;
static constexpr ck_tile::index_t kPageSize = kPageSize_;
static constexpr bool kEnablePaging = kEnablePaging_;
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
static constexpr index_t kBlockM = cfg::BlockM;
@@ -301,7 +334,9 @@ struct unified_attention_kernel_traits
unified_attention_traits>;
using unified_attention_pipeline =
typename cfg::template Pipeline<unified_attention_pipeline_problem, kPageSize_>;
typename cfg::template Pipeline<unified_attention_pipeline_problem,
kPageSize_,
kEnablePaging_>;
using epilogue =
Default2DEpilogue<Default2DEpilogueProblem<typename dt::acc_dtype,
@@ -402,31 +437,47 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
} // namespace ck_tile
// One-line instantiation per (V, DataType, IsMasking, PageSize) combination.
// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0
// is the legacy runtime-page-size instance (catch-all fallback). Non-zero
// values pin the runtime `page_size` argument to that literal — see the
// dispatch_variant<V> switch in unified_attention.cpp for routing.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch< \
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_>>(const unified_attention_args& args, \
const stream_config& config) \
{ \
using Traits = unified_attention_kernel_traits< \
KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_>; \
return std::make_pair(true, \
unified_attention_kernel_launch<typename Traits::kernel, \
Traits::kUseDecodeGrid>(args, config)); \
// One-line instantiation per (V, DataType, IsMasking, PageSize, EnablePaging)
// combination. Each instance .cpp consists of exactly one of these calls.
// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = true : legacy runtime-page-size
// instance (catch-all paged
// fallback).
// - PAGE_SIZE_ > 0 + ENABLE_PAGING_ = true : constexpr-pinned paged.
// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = false : contiguous (THD) instance.
// See dispatch_variant<V> in unified_attention.cpp for runtime routing.
// Expansion: a full (`template <>`) specialization of
// unified_attention_kernel_dispatch for the unified_attention_kernel_traits
// type named by the five macro arguments. The generated body launches
// Traits::kernel through unified_attention_kernel_launch (parameterised on
// Traits::kUseDecodeGrid) and returns {true, <value returned by the launch>}.
// VARIANT_ and DTYPE_ are pasted after `KernelVariant::` and
// `unified_attention_args::data_type_enum::` respectively, so pass the bare
// enumerator names.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_NP( \
VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, ENABLE_PAGING_) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch< \
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_, \
ENABLE_PAGING_>>(const unified_attention_args& args, \
const stream_config& config) \
{ \
using Traits = unified_attention_kernel_traits< \
KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_, \
ENABLE_PAGING_>; \
return std::make_pair(true, \
unified_attention_kernel_launch<typename Traits::kernel, \
Traits::kUseDecodeGrid>(args, config)); \
}
// Backward-compat shorthand: PageSize-only specialization on the paged path.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, true)
// Backward-compat shorthand for the existing one-liners — the default
// `PageSize = 0` instance is the catch-all runtime-page-size kernel.
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0)
// PageSize = 0, paged, catch-all runtime-page-size kernel.
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, true)
// Contiguous (THD) K/V instance. PageSize is forced to 0 (ignored anyway
// when EnablePaging = false). One instance per (variant, dtype, mask)
// triple.
#define INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(VARIANT_, DTYPE_, IS_MASK_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, false)

View File

@@ -48,22 +48,38 @@
namespace ck_tile {
// kPageSize_ : non-type template parameter that pins the runtime
// `page_size` argument to a compile-time constant when > 0. The host
// dispatcher selects an instance whose kPageSize_ matches `args.page_blk_size`
// and routes execution there; instances compiled with kPageSize_ == 0 keep
// the legacy runtime-page-size path and serve as the catch-all fallback for
// uncommon page sizes. Having the value at compile time:
// 1. lets the compiler strength-reduce every `/ page_size`, `* page_size`,
// `% page_size` into shift / multiply-by-magic-constant on the literal
// (e.g. div-by-32 → shr 5);
// 2. lets the Tier 0 / Tier 2 gate use the real `KY0_step_N <= kPageSize`
// condition instead of the conservative `KY0_step_N <= 16` hedge, so
// prefill_d128 bf16, prefill_d64 bf16, and prefill_d64 fp8 also gain
// the scalar-promote + LDS-cache fast path on their natural page sizes.
// kPageSize_ : non-type template parameter that pins the runtime
// `page_size` argument to a compile-time constant when > 0.
// The host dispatcher selects an instance whose kPageSize_
// matches `args.page_blk_size` and routes execution there;
// instances compiled with kPageSize_ == 0 keep the legacy
// runtime-page-size path and serve as the catch-all
// fallback for uncommon page sizes. Having the value at
// compile time:
// 1. lets the compiler strength-reduce every / * %
// page_size into shift / multiply-by-magic;
// 2. lets the Tier 0 / Tier 2 gate use the real
// `KY0_step_N <= kPageSize` condition instead of
// the conservative `<= 16` hedge.
// Only meaningful when kEnablePaging_ == true; ignored
// (and conventionally set to 0) when paging is disabled.
// kEnablePaging_ : true (default) is the paged-KV-cache layout used by
// vLLM/SGLang inference servers — K/V are stored as
// [num_blocks, page_size, num_heads, head_dim] and a
// `block_tables` index resolves logical → physical pages.
// false is the contiguous "THD" layout used by
// pretraining / flash-attention-style callers —
// K/V are stored as [num_kv_tokens, num_heads, head_dim]
// and `refresh_*_offsets` just multiplies the logical
// token by the row stride. The contiguous path skips
// the entire block_tables fetch, the per-tile / %
// page_size arithmetic, and the Tier 0 / Tier 2
// LDS-cache machinery — `block_tables_ptr` and
// `page_size_runtime` are ignored in that mode.
template <typename Problem_,
typename Policy_ = UnifiedAttentionPipelineDefaultPolicy,
ck_tile::index_t kPageSize_ = 0>
ck_tile::index_t kPageSize_ = 0,
bool kEnablePaging_ = true>
struct UnifiedAttentionPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
@@ -72,6 +88,8 @@ struct UnifiedAttentionPipeline
// Compile-time page size (0 = runtime). See class-level comment above.
static constexpr ck_tile::index_t kPageSize = kPageSize_;
static constexpr bool kHasCePageSize = (kPageSize_ > 0);
// Paged KV cache vs contiguous THD layout. See class-level comment above.
static constexpr bool kEnablePaging = kEnablePaging_;
using QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
@@ -147,6 +165,10 @@ struct UnifiedAttentionPipeline
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPageTableLdsBytes()
{
// Contiguous (THD) layout doesn't go through block_tables at all,
// so the LDS cache is dead — skip the allocation entirely.
if constexpr(!kEnablePaging) return 0;
// Allocate the cache only for the kernel instances where Tier 0's
// constexpr gate fires (otherwise the lambdas wouldn't read it and
// the LDS would sit idle, hurting occupancy for nothing). Mirror the
@@ -677,12 +699,16 @@ struct UnifiedAttentionPipeline
// (newly ON — biggest win)
// prefill_d64 bf16 KY0_step_N=64 @ ps=64 74.4 → 73.4 ms (-1.3%)
// (newly ON; small win)
// Tier 0 + Tier 2 only make sense on the paged path — the
// contiguous (THD) layout has no block_tables to scalar-promote.
constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
constexpr bool kScalarPromoteKPageIdx =
kEnablePaging &&
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap);
constexpr bool kScalarPromoteVPageIdx =
kEnablePaging &&
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap);
@@ -742,7 +768,22 @@ struct UnifiedAttentionPipeline
auto refresh_k_offsets = [&](index_t k_tile_idx) {
static_for<0, KNRepeat, 1>{}([&](auto i) {
if constexpr(kScalarPromoteKPageIdx)
if constexpr(!kEnablePaging)
{
// Contiguous (THD) layout: K is one flat
// [num_kv_tokens, head_dim] tensor for the current
// kv-head, so the offset is just `token * row_stride`.
// No block_tables fetch, no / % page_size — the entire
// per-tile address-comp chain collapses to a single
// imad on a per-lane token index, which frees the cycles
// that Tier 0/2 were spending to pay down the indirection.
const index_t logical_token = split_token_offset +
k_tile_idx * kPageBlockSize + k_thread_n_pos +
static_cast<index_t>(i.value) * KY0_step_N;
k_page_offsets(i) =
static_cast<long_index_t>(logical_token) * k_row_stride;
}
else if constexpr(kScalarPromoteKPageIdx)
{
// Compute the uniform per-`i` base in scalar; force the
// resulting page-table index into an SGPR. Tier 2 reads
@@ -787,7 +828,16 @@ struct UnifiedAttentionPipeline
};
auto refresh_v_offsets = [&](index_t v_tile_idx) {
static_for<0, VNRepeat, 1>{}([&](auto i) {
if constexpr(kScalarPromoteVPageIdx)
if constexpr(!kEnablePaging)
{
// Contiguous (THD) layout: see refresh_k_offsets above.
const index_t logical_token = split_token_offset +
v_tile_idx * kPageBlockSize + v_thread_n_pos +
static_cast<index_t>(i.value) * VY0_step_N;
v_page_offsets(i) =
static_cast<long_index_t>(logical_token) * v_row_stride;
}
else if constexpr(kScalarPromoteVPageIdx)
{
const index_t i_base_token = split_token_offset +
v_tile_idx * kPageBlockSize +