mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
CK-UA: optional paging — contiguous (THD) K/V path, prefill_d128 fp8 -28%
Add a `bool kEnablePaging_` non-type template parameter on
UnifiedAttentionPipeline (default true preserves the paged behaviour).
When false, `refresh_*_offsets` collapses to a single per-row
`logical_token * row_stride` imad — no block_tables fetch, no
/ % page_size arithmetic, no Tier 0 scalar-promote, no Tier 2 LDS-cache
populate. The host selects between paths via a new
`args.kv_contiguous` runtime flag plumbed through dispatch_variant<V>.
Twelve new prefill instances pin EnablePaging=false:
prefill_d{64,128} × {fp16, bf16, fp8} × {mask, nmask}
Decode variants stay on the paged path — callers without a KV cache
don't have decode workloads, and the binary-size cost isn't justified.
Measured impact on the same physical K/V memory (sq=1×4096, causal,
page_size=32 paged baseline, MI355, n=30 iters):
variant sk paged contig Δ
prefill_d64 bf16 4096 0.274 0.227 -17.1 %
prefill_d64 bf16 16384 1.529 1.198 -21.6 %
prefill_d64 bf16 32768 3.218 2.505 -22.1 %
prefill_d64 fp8 4096 0.299 0.235 -21.4 %
prefill_d64 fp8 16384 1.489 1.150 -22.7 %
prefill_d64 fp8 32768 3.054 2.386 -21.9 %
prefill_d128 bf16 4096 0.493 0.397 -19.3 %
prefill_d128 bf16 16384 2.638 2.224 -15.7 %
prefill_d128 bf16 32768 5.731 4.598 -19.8 %
prefill_d128 fp8 4096 0.476 0.341 -28.3 %
prefill_d128 fp8 16384 2.416 1.792 -25.8 %
prefill_d128 fp8 32768 4.973 3.727 -25.0 %
prefill_d128 fp8 at -28 % is the single biggest UA optimisation
measured to date — bigger than Tier 0 (-12 %), Tier 2 (-5 %), and the
Tier-3 d=64 fp8 win (-16 %).
Correctness validated by bit-exact comparison against the paged
instance with page_size=32 and identity block_tables on 48 shape ×
dtype × mask combinations.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d128, fp8, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, bf16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp16, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, true)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(prefill_d64, fp8, false)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -115,12 +115,13 @@ namespace {
|
||||
template <KernelVariant V,
|
||||
unified_attention_args::data_type_enum DType,
|
||||
bool IsMask,
|
||||
index_t PageSize>
|
||||
index_t PageSize,
|
||||
bool EnablePaging = true>
|
||||
std::pair<bool, float> dispatch_one(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
return unified_attention_kernel_dispatch<
|
||||
unified_attention_kernel_traits<V, DType, IsMask, PageSize>>(args, config);
|
||||
unified_attention_kernel_traits<V, DType, IsMask, PageSize, EnablePaging>>(args, config);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -132,7 +133,13 @@ std::pair<bool, float> dispatch_one(const unified_attention_args& args,
|
||||
// to the matching one; anything else (or any non-prefill variant) goes to
|
||||
// the PageSize=0 catch-all that keeps the legacy runtime-page-size code.
|
||||
//
|
||||
// The fast-path payoff is two-fold:
|
||||
// When args.kv_contiguous is set, we instead route to the contiguous-K/V
|
||||
// (THD-layout) instance — a third compile-time variant per (prefill,
|
||||
// dtype, mask) triple that skips the block_tables fetch entirely. Decode
|
||||
// variants don't have a contiguous instance (callers with no KV cache
|
||||
// don't have decode workloads); they fall back to the paged catch-all.
|
||||
//
|
||||
// The paged fast-path payoff is two-fold:
|
||||
// 1. Compile-time `page_size` lets the compiler strength-reduce every
|
||||
// `/ page_size`, `* page_size`, and `% page_size` inside the per-tile
|
||||
// address chain. div-by-32 collapses to `shr 5`, etc.
|
||||
@@ -150,6 +157,11 @@ std::pair<bool, float> dispatch_page_size(const unified_attention_args& args,
|
||||
if constexpr(V == KernelVariant::prefill_d128 ||
|
||||
V == KernelVariant::prefill_d64)
|
||||
{
|
||||
if(args.kv_contiguous)
|
||||
{
|
||||
// Contiguous (THD) instance — PageSize is irrelevant, EnablePaging=false.
|
||||
return dispatch_one<V, DType, IsMask, 0, /*EnablePaging=*/false>(args, config);
|
||||
}
|
||||
switch(args.page_blk_size)
|
||||
{
|
||||
case 16: return dispatch_one<V, DType, IsMask, 16>(args, config);
|
||||
|
||||
@@ -81,6 +81,26 @@ struct unified_attention_args
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
|
||||
// Layout selector for K/V:
|
||||
// false (default) : paged KV cache — K/V are
|
||||
// [num_blks, page_blk_size, num_kv_heads, head_size]
|
||||
// and `block_tables_ptr` resolves logical → physical
|
||||
// page. Used by vLLM / SGLang inference servers.
|
||||
// true : contiguous (THD) — K/V are
|
||||
// [num_kv_tokens, num_kv_heads, head_size] for the
|
||||
// current request and `block_tables_ptr` is ignored.
|
||||
// The kernel skips the per-tile page-table fetch and
|
||||
// the / % page_size arithmetic entirely. Used by
|
||||
// pretraining / flash-attention-style callers that
|
||||
// don't have a shared KV cache. When this is true,
|
||||
// `page_blk_size` is ignored; treat `num_blks` as a
|
||||
// virtual page count: the K/V tensor view still has
|
||||
// shape (num_blks * page_blk_size, head_dim), so the
|
||||
// caller can either set num_blks=num_kv_tokens with
|
||||
// page_blk_size=1 or any equivalent factorisation
|
||||
// that yields the right `num_kv_tokens` total.
|
||||
bool kv_contiguous = false;
|
||||
|
||||
// Set to true when the K/V cache is large enough that an int32 byte
|
||||
// offset into it can overflow (i.e. when
|
||||
// num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX
|
||||
|
||||
@@ -114,11 +114,18 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
|
||||
template <KernelVariant V>
|
||||
struct variant_config;
|
||||
|
||||
// Each variant_config exposes `Pipeline<Problem, PageSize>` so the traits
|
||||
// can pin the page size at compile time. PageSize=0 means "runtime page
|
||||
// size" (legacy behaviour); the host dispatcher selects a non-zero value
|
||||
// when it can prove the runtime `args.page_blk_size` matches one of the
|
||||
// instances we compiled.
|
||||
// Each variant_config exposes `Pipeline<Problem, PageSize, EnablePaging>`
|
||||
// so the traits can pin the page size and the paged-vs-contiguous toggle
|
||||
// at compile time:
|
||||
// - PageSize=0 : runtime page size (legacy paged behaviour).
|
||||
// - PageSize>0 : compile-time-pinned paged behaviour.
|
||||
// - EnablePaging=true (default): K/V are a paged KV cache.
|
||||
// - EnablePaging=false : K/V are contiguous (THD) tensors —
|
||||
// refresh_*_offsets skips block_tables
|
||||
// entirely, PageSize is ignored.
|
||||
// The host dispatcher (dispatch_variant in unified_attention.cpp) picks
|
||||
// the matching instance from `args.page_blk_size` / `args.kv_contiguous`
|
||||
// at launch time.
|
||||
template <>
|
||||
struct variant_config<KernelVariant::prefill_d128>
|
||||
{
|
||||
@@ -127,9 +134,11 @@ struct variant_config<KernelVariant::prefill_d128>
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineDefaultPolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
@@ -141,9 +150,11 @@ struct variant_config<KernelVariant::decode_d128_m128>
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineDefaultPolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
@@ -155,9 +166,11 @@ struct variant_config<KernelVariant::decode_d128_m32>
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineTinyDecodePolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
@@ -169,9 +182,11 @@ struct variant_config<KernelVariant::decode_d128_m16>
|
||||
static constexpr index_t BlockSize = 32;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<16, 16, 32>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineTinyDecodePolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
@@ -183,9 +198,11 @@ struct variant_config<KernelVariant::prefill_d64>
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<8, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineDefaultPolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
@@ -197,9 +214,11 @@ struct variant_config<KernelVariant::decode_d64_m128>
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<4, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDefaultPolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineDefaultPolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = false;
|
||||
};
|
||||
|
||||
@@ -211,9 +230,11 @@ struct variant_config<KernelVariant::decode_d64_m64>
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<2, 1, 1>;
|
||||
using WarpGemmShape = sequence<32, 32, 16>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineDecodePolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineDecodePolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
@@ -225,9 +246,11 @@ struct variant_config<KernelVariant::decode_d64_m16>
|
||||
static constexpr index_t BlockSize = 64;
|
||||
using BlockWarps = sequence<1, 1, 1>;
|
||||
using WarpGemmShape = sequence<16, 16, 32>;
|
||||
template <typename Problem, index_t PageSize = 0>
|
||||
using Pipeline =
|
||||
UnifiedAttentionPipeline<Problem, UnifiedAttentionPipelineTinyDecodePolicy, PageSize>;
|
||||
template <typename Problem, index_t PageSize = 0, bool EnablePaging = true>
|
||||
using Pipeline = UnifiedAttentionPipeline<Problem,
|
||||
UnifiedAttentionPipelineTinyDecodePolicy,
|
||||
PageSize,
|
||||
EnablePaging>;
|
||||
static constexpr bool kUseDecodeGrid = true;
|
||||
};
|
||||
|
||||
@@ -237,25 +260,35 @@ struct variant_config<KernelVariant::decode_d64_m16>
|
||||
// Single templated trait. Pulls per-variant knobs from variant_config<V> and
|
||||
// per-dtype element types from unified_attention_problem_traits<DataType>.
|
||||
// =============================================================================
|
||||
// kPageSize: optional compile-time pin of the runtime `page_size`. Default
|
||||
// 0 keeps the legacy runtime-page-size behaviour; a non-zero value lets the
|
||||
// pipeline strength-reduce the per-tile arithmetic *and* widen the Tier 0 /
|
||||
// Tier 2 gate from the conservative `KY0_step_N <= 16` hedge to the real
|
||||
// `KY0_step_N <= kPageSize` condition. The host dispatcher (dispatch_variant
|
||||
// in unified_attention.cpp) picks the matching instance at launch time.
|
||||
// kPageSize : optional compile-time pin of the runtime `page_size`.
|
||||
// Default 0 keeps the legacy runtime-page-size behaviour;
|
||||
// a non-zero value lets the pipeline strength-reduce the
|
||||
// per-tile arithmetic *and* widen the Tier 0 / Tier 2 gate
|
||||
// from the conservative `KY0_step_N <= 16` hedge to the
|
||||
// real `KY0_step_N <= kPageSize` condition.
|
||||
// kEnablePaging : true (default) — paged KV cache (vLLM/SGLang). false —
|
||||
// contiguous (THD) K/V tensors that skip the entire
|
||||
// block_tables fetch chain. Mutually exclusive use cases;
|
||||
// only one path is emitted per instance. PageSize is
|
||||
// ignored when kEnablePaging == false.
|
||||
// The host dispatcher (dispatch_variant in unified_attention.cpp) picks
|
||||
// the matching instance from `args.page_blk_size` / `args.kv_contiguous`
|
||||
// at launch time.
|
||||
template <KernelVariant V,
|
||||
unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
ck_tile::index_t kPageSize_ = 0>
|
||||
ck_tile::index_t kPageSize_ = 0,
|
||||
bool kEnablePaging_ = true>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
using cfg = variant_config<V>;
|
||||
using dt = unified_attention_problem_traits<DataType>;
|
||||
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr KernelVariant variant = V;
|
||||
static constexpr ck_tile::index_t kPageSize = kPageSize_;
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr KernelVariant variant = V;
|
||||
static constexpr ck_tile::index_t kPageSize = kPageSize_;
|
||||
static constexpr bool kEnablePaging = kEnablePaging_;
|
||||
|
||||
static constexpr index_t HEAD_SIZE = cfg::HeadSize;
|
||||
static constexpr index_t kBlockM = cfg::BlockM;
|
||||
@@ -301,7 +334,9 @@ struct unified_attention_kernel_traits
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
typename cfg::template Pipeline<unified_attention_pipeline_problem, kPageSize_>;
|
||||
typename cfg::template Pipeline<unified_attention_pipeline_problem,
|
||||
kPageSize_,
|
||||
kEnablePaging_>;
|
||||
|
||||
using epilogue =
|
||||
Default2DEpilogue<Default2DEpilogueProblem<typename dt::acc_dtype,
|
||||
@@ -402,31 +437,47 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize) combination.
|
||||
// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0
|
||||
// is the legacy runtime-page-size instance (catch-all fallback). Non-zero
|
||||
// values pin the runtime `page_size` argument to that literal — see the
|
||||
// dispatch_variant<V> switch in unified_attention.cpp for routing.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch< \
|
||||
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
{ \
|
||||
using Traits = unified_attention_kernel_traits< \
|
||||
KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_>; \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<typename Traits::kernel, \
|
||||
Traits::kUseDecodeGrid>(args, config)); \
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize, EnablePaging)
|
||||
// combination. Each instance .cpp consists of exactly one of these calls.
|
||||
// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = true : legacy runtime-page-size
|
||||
// instance (catch-all paged
|
||||
// fallback).
|
||||
// - PAGE_SIZE_ > 0 + ENABLE_PAGING_ = true : constexpr-pinned paged.
|
||||
// - PAGE_SIZE_ = 0 + ENABLE_PAGING_ = false : contiguous (THD) instance.
|
||||
// See dispatch_variant<V> in unified_attention.cpp for runtime routing.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_NP( \
|
||||
VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, ENABLE_PAGING_) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch< \
|
||||
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_, \
|
||||
ENABLE_PAGING_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
{ \
|
||||
using Traits = unified_attention_kernel_traits< \
|
||||
KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_, \
|
||||
ENABLE_PAGING_>; \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<typename Traits::kernel, \
|
||||
Traits::kUseDecodeGrid>(args, config)); \
|
||||
}
|
||||
|
||||
// Backward-compat shorthand: PageSize-only specialization on the paged path.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, true)
|
||||
|
||||
// Backward-compat shorthand for the existing one-liners — the default
|
||||
// `PageSize = 0` instance is the catch-all runtime-page-size kernel.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0)
|
||||
// PageSize = 0, paged, catch-all runtime-page-size kernel.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, true)
|
||||
|
||||
// Contiguous (THD) K/V instance. PageSize is forced to 0 (ignored anyway
|
||||
// when EnablePaging = false). One instance per (variant, dtype, mask)
|
||||
// triple.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_NOPAGE(VARIANT_, DTYPE_, IS_MASK_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_NP(VARIANT_, DTYPE_, IS_MASK_, 0, false)
|
||||
|
||||
@@ -48,22 +48,38 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// kPageSize_ : non-type template parameter that pins the runtime
|
||||
// `page_size` argument to a compile-time constant when > 0. The host
|
||||
// dispatcher selects an instance whose kPageSize_ matches `args.page_blk_size`
|
||||
// and routes execution there; instances compiled with kPageSize_ == 0 keep
|
||||
// the legacy runtime-page-size path and serve as the catch-all fallback for
|
||||
// uncommon page sizes. Having the value at compile time:
|
||||
// 1. lets the compiler strength-reduce every `/ page_size`, `* page_size`,
|
||||
// `% page_size` into shift / multiply-by-magic-constant on the literal
|
||||
// (e.g. div-by-32 → shr 5);
|
||||
// 2. lets the Tier 0 / Tier 2 gate use the real `KY0_step_N <= kPageSize`
|
||||
// condition instead of the conservative `KY0_step_N <= 16` hedge, so
|
||||
// prefill_d128 bf16, prefill_d64 bf16, and prefill_d64 fp8 also gain
|
||||
// the scalar-promote + LDS-cache fast path on their natural page sizes.
|
||||
// kPageSize_ : non-type template parameter that pins the runtime
|
||||
// `page_size` argument to a compile-time constant when > 0.
|
||||
// The host dispatcher selects an instance whose kPageSize_
|
||||
// matches `args.page_blk_size` and routes execution there;
|
||||
// instances compiled with kPageSize_ == 0 keep the legacy
|
||||
// runtime-page-size path and serve as the catch-all
|
||||
// fallback for uncommon page sizes. Having the value at
|
||||
// compile time:
|
||||
// 1. lets the compiler strength-reduce every / * %
|
||||
// page_size into shift / multiply-by-magic;
|
||||
// 2. lets the Tier 0 / Tier 2 gate use the real
|
||||
// `KY0_step_N <= kPageSize` condition instead of
|
||||
// the conservative `<= 16` hedge.
|
||||
// Only meaningful when kEnablePaging_ == true; ignored
|
||||
// (and conventionally set to 0) when paging is disabled.
|
||||
// kEnablePaging_ : true (default) is the paged-KV-cache layout used by
|
||||
// vLLM/SGLang inference servers — K/V are stored as
|
||||
// [num_blocks, page_size, num_heads, head_dim] and a
|
||||
// `block_tables` index resolves logical → physical pages.
|
||||
// false is the contiguous "THD" layout used by
|
||||
// pretraining / flash-attention-style callers —
|
||||
// K/V are stored as [num_kv_tokens, num_heads, head_dim]
|
||||
// and `refresh_*_offsets` just multiplies the logical
|
||||
// token by the row stride. The contiguous path skips
|
||||
// the entire block_tables fetch, the per-tile / %
|
||||
// page_size arithmetic, and the Tier 0 / Tier 2
|
||||
// LDS-cache machinery — `block_tables_ptr` and
|
||||
// `page_size_runtime` are ignored in that mode.
|
||||
template <typename Problem_,
|
||||
typename Policy_ = UnifiedAttentionPipelineDefaultPolicy,
|
||||
ck_tile::index_t kPageSize_ = 0>
|
||||
ck_tile::index_t kPageSize_ = 0,
|
||||
bool kEnablePaging_ = true>
|
||||
struct UnifiedAttentionPipeline
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
@@ -72,6 +88,8 @@ struct UnifiedAttentionPipeline
|
||||
// Compile-time page size (0 = runtime). See class-level comment above.
|
||||
static constexpr ck_tile::index_t kPageSize = kPageSize_;
|
||||
static constexpr bool kHasCePageSize = (kPageSize_ > 0);
|
||||
// Paged KV cache vs contiguous THD layout. See class-level comment above.
|
||||
static constexpr bool kEnablePaging = kEnablePaging_;
|
||||
using QDataType = ck_tile::remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
|
||||
@@ -147,6 +165,10 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPageTableLdsBytes()
|
||||
{
|
||||
// Contiguous (THD) layout doesn't go through block_tables at all,
|
||||
// so the LDS cache is dead — skip the allocation entirely.
|
||||
if constexpr(!kEnablePaging) return 0;
|
||||
|
||||
// Allocate the cache only for the kernel instances where Tier 0's
|
||||
// constexpr gate fires (otherwise the lambdas wouldn't read it and
|
||||
// the LDS would sit idle, hurting occupancy for nothing). Mirror the
|
||||
@@ -677,12 +699,16 @@ struct UnifiedAttentionPipeline
|
||||
// (newly ON — biggest win)
|
||||
// prefill_d64 bf16 KY0_step_N=64 @ ps=64 74.4 → 73.4 ms (-1.3%)
|
||||
// (newly ON; small win)
|
||||
// Tier 0 + Tier 2 only make sense on the paged path — the
|
||||
// contiguous (THD) layout has no block_tables to scalar-promote.
|
||||
constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
|
||||
constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
|
||||
constexpr bool kScalarPromoteKPageIdx =
|
||||
kEnablePaging &&
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap);
|
||||
constexpr bool kScalarPromoteVPageIdx =
|
||||
kEnablePaging &&
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap);
|
||||
|
||||
@@ -742,7 +768,22 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto refresh_k_offsets = [&](index_t k_tile_idx) {
|
||||
static_for<0, KNRepeat, 1>{}([&](auto i) {
|
||||
if constexpr(kScalarPromoteKPageIdx)
|
||||
if constexpr(!kEnablePaging)
|
||||
{
|
||||
// Contiguous (THD) layout: K is one flat
|
||||
// [num_kv_tokens, head_dim] tensor for the current
|
||||
// kv-head, so the offset is just `token * row_stride`.
|
||||
// No block_tables fetch, no / % page_size — the entire
|
||||
// per-tile address-comp chain collapses to a single
|
||||
// imad on a per-lane token index. Frees up the cycles
|
||||
// Tier 0/2 were spending paying down the indirection.
|
||||
const index_t logical_token = split_token_offset +
|
||||
k_tile_idx * kPageBlockSize + k_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * KY0_step_N;
|
||||
k_page_offsets(i) =
|
||||
static_cast<long_index_t>(logical_token) * k_row_stride;
|
||||
}
|
||||
else if constexpr(kScalarPromoteKPageIdx)
|
||||
{
|
||||
// Compute the uniform per-`i` base in scalar; force the
|
||||
// resulting page-table index into an SGPR. Tier 2 reads
|
||||
@@ -787,7 +828,16 @@ struct UnifiedAttentionPipeline
|
||||
};
|
||||
auto refresh_v_offsets = [&](index_t v_tile_idx) {
|
||||
static_for<0, VNRepeat, 1>{}([&](auto i) {
|
||||
if constexpr(kScalarPromoteVPageIdx)
|
||||
if constexpr(!kEnablePaging)
|
||||
{
|
||||
// Contiguous (THD) layout: see refresh_k_offsets above.
|
||||
const index_t logical_token = split_token_offset +
|
||||
v_tile_idx * kPageBlockSize + v_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * VY0_step_N;
|
||||
v_page_offsets(i) =
|
||||
static_cast<long_index_t>(logical_token) * v_row_stride;
|
||||
}
|
||||
else if constexpr(kScalarPromoteVPageIdx)
|
||||
{
|
||||
const index_t i_base_token = split_token_offset +
|
||||
v_tile_idx * kPageBlockSize +
|
||||
|
||||
Reference in New Issue
Block a user