From beb3036f8aff43227941ba066d4f5baafce1ca41 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Thu, 28 May 2026 14:02:26 +0000 Subject: [PATCH] UA: make sink parameters available to the traits --- .../unified_attention_impl.hpp | 47 +++++++++++++------ .../tile_unified_attention_traits.hpp | 9 +++- .../unified_attention_pipeline_problem.hpp | 6 +++ 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 402eea861d..53abdc4a81 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -248,7 +248,8 @@ template + bool IsLocal_ = false, + bool kHasSink_ = false> struct unified_attention_kernel_traits { using cfg = variant_config; @@ -257,6 +258,7 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; static constexpr bool is_local = IsLocal_; + static constexpr bool has_sink = kHasSink_; static constexpr KernelVariant variant = V; static constexpr ck_tile::index_t kPageSize = kPageSize_; @@ -315,9 +317,10 @@ struct unified_attention_kernel_traits unified_attention_warp_gemm_shape, true>; // IsVLayoutRowMajor - using unified_attention_traits = TileUnifiedAttentionTraits; // kBlockPerCu + using unified_attention_traits = TileUnifiedAttentionTraits; // kHasSink using unified_attention_mask = GenericAttentionMask; using unified_attention_pipeline_problem = @@ -441,21 +444,27 @@ std::pair unified_attention_kernel_dispatch(const unified_attention } // namespace ck_tile -// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal) -// combination. Each instance .cpp consists of exactly one of these calls. -// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all -// fallback). IS_LOCAL_ = false is the non-SWA path (causal / no-mask); -// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the -// left and right window bounds inside the mask. -#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \ - IS_LOCAL_) \ +// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal, +// kHasSink) combination. Each instance .cpp consists of exactly one of +// these calls. PAGE_SIZE_ = 0 is the legacy runtime-page-size instance +// (catch-all fallback). IS_LOCAL_ = false is the non-SWA path (causal / +// no-mask); IS_LOCAL_ = true compiles the SWA-capable kernel that honours +// both the left and right window bounds inside the mask. HAS_SINK_ = true +// compiles the sink-aware kernel which seeds the online softmax with a +// per-Q-head virtual key (GPT-OSS / vLLM convention); HAS_SINK_ = false is +// the classic no-sink softmax. No instance file flips HAS_SINK_ yet — the +// trait knob exists so the pipeline can introduce an `if constexpr +// (kHasSink)` branch without changing codegen on the existing instances. +#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_,\ + IS_LOCAL_, HAS_SINK_) \ template <> \ std::pair unified_attention_kernel_dispatch< \ unified_attention_kernel_traits>(const unified_attention_args& args, \ + IS_LOCAL_, \ + HAS_SINK_>>(const unified_attention_args& args, \ const stream_config& config) \ { \ using Traits = unified_attention_kernel_traits< \ @@ -463,14 +472,22 @@ std::pair unified_attention_kernel_dispatch(const unified_attention unified_attention_args::data_type_enum::DTYPE_, \ IS_MASK_, \ PAGE_SIZE_, \ - IS_LOCAL_>; \ + IS_LOCAL_, \ + HAS_SINK_>; \ return std::make_pair(true, \ unified_attention_kernel_launch(args, config)); \ } // Backward-compat wrappers — every existing instance .cpp uses one of these -// and defaults to `IsLocal = false` (the non-SWA path). +// and defaults to `IsLocal = false` (the non-SWA path) and `kHasSink = false` +// (the classic softmax). Each wrapper forwards through the canonical 6-arg +// _PS_LOCAL_SINK form, so adding a sink instance only requires the new macro. +#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \ + IS_LOCAL_) \ + INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK( \ + VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, IS_LOCAL_, false) + #define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \ INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false) diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index 8b01a5722d..198949c21b 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -9,11 +9,18 @@ namespace ck_tile { template + index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */, + bool kHasSink_ = false /* learnable per-Q-head attention sink */> struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; static constexpr bool kPadHeadDim = kPadHeadDim_; static constexpr index_t kBlockPerCu = kBlockPerCu_; + // When true, the pipeline reads a per-Q-head sink scalar at init + // time and seeds the online softmax with the corresponding virtual + // key (GPT-OSS / vLLM convention). The kernel forwards the pointer + // via `kargs.sink_ptr`. Default `false` reproduces the classic + // no-sink softmax; no instance flips this yet. + static constexpr bool kHasSink = kHasSink_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 2b655c74b3..ca44902a4c 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -53,5 +53,11 @@ struct UnifiedAttentionPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + // Learnable per-Q-head attention sink. False reproduces the classic + // no-sink softmax; true seeds the online softmax denominator with one + // virtual key per Q head (GPT-OSS / vLLM convention). Threaded through + // from `TileUnifiedAttentionTraits::kHasSink` so a single trait knob + // controls both the pipeline init and the kernel-side pointer arithmetic. + static constexpr bool kHasSink = Traits::kHasSink; }; } // namespace ck_tile