UA: make sink parameters available to the traits

This commit is contained in:
Damien Lejeune
2026-05-28 14:02:26 +00:00
parent e7cb485d98
commit beb3036f8a
3 changed files with 46 additions and 16 deletions

View File

@@ -248,7 +248,8 @@ template <KernelVariant V,
unified_attention_args::data_type_enum DataType,
bool IsMasking,
ck_tile::index_t kPageSize_ = 0,
bool IsLocal_ = false>
bool IsLocal_ = false,
bool kHasSink_ = false>
struct unified_attention_kernel_traits
{
using cfg = variant_config<V>;
@@ -257,6 +258,7 @@ struct unified_attention_kernel_traits
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr bool is_local = IsLocal_;
static constexpr bool has_sink = kHasSink_;
static constexpr KernelVariant variant = V;
static constexpr ck_tile::index_t kPageSize = kPageSize_;
@@ -315,9 +317,10 @@ struct unified_attention_kernel_traits
unified_attention_warp_gemm_shape,
true>; // IsVLayoutRowMajor
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
-1>; // kBlockPerCu
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
-1, // kBlockPerCu
kHasSink_>; // kHasSink
using unified_attention_mask = GenericAttentionMask<IsMasking, IsLocal_>;
using unified_attention_pipeline_problem =
@@ -441,21 +444,27 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
} // namespace ck_tile
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal)
// combination. Each instance .cpp consists of exactly one of these calls.
// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all
// fallback). IS_LOCAL_ = false is the non-SWA path (causal / no-mask);
// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the
// left and right window bounds inside the mask.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
IS_LOCAL_) \
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal,
// kHasSink) combination. Each instance .cpp consists of exactly one of
// these calls. PAGE_SIZE_ = 0 is the legacy runtime-page-size instance
// (catch-all fallback). IS_LOCAL_ = false is the non-SWA path (causal /
// no-mask); IS_LOCAL_ = true compiles the SWA-capable kernel that honours
// both the left and right window bounds inside the mask. HAS_SINK_ = true
// compiles the sink-aware kernel which seeds the online softmax with a
// per-Q-head virtual key (GPT-OSS / vLLM convention); HAS_SINK_ = false is
// the classic no-sink softmax. No instance file flips HAS_SINK_ yet — the
// trait knob exists so the pipeline can introduce an `if constexpr
// (kHasSink)` branch without changing codegen on the existing instances.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_,\
IS_LOCAL_, HAS_SINK_) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch< \
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_, \
IS_LOCAL_>>(const unified_attention_args& args, \
IS_LOCAL_, \
HAS_SINK_>>(const unified_attention_args& args, \
const stream_config& config) \
{ \
using Traits = unified_attention_kernel_traits< \
@@ -463,14 +472,22 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_, \
IS_LOCAL_>; \
IS_LOCAL_, \
HAS_SINK_>; \
return std::make_pair(true, \
unified_attention_kernel_launch<typename Traits::kernel, \
Traits::kUseDecodeGrid>(args, config)); \
}
// Backward-compat wrappers — every existing instance .cpp uses one of these
// and defaults to `IsLocal = false` (the non-SWA path).
// and defaults to `IsLocal = false` (the non-SWA path) and `kHasSink = false`
// (the classic softmax). Each wrapper forwards through the canonical 6-arg
// _PS_LOCAL_SINK form, so adding a sink instance only requires the new macro.
// 5-argument form: expands to INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK
// with the trailing HAS_SINK_ argument fixed to `false`; the five arguments
// received here are passed through unchanged and in the same order.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
IS_LOCAL_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK( \
VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, IS_LOCAL_, false)
// 4-argument form: expands to INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL with
// the trailing IS_LOCAL_ argument fixed to `false`; the four arguments
// received here are passed through unchanged and in the same order.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false)

View File

@@ -9,11 +9,18 @@ namespace ck_tile {
// Compile-time trait bundle: every template parameter below is re-exported
// unchanged as a `static constexpr` member of the same name (minus the
// trailing underscore) so that consumers can read it from the traits type.
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim_ /* padding for hdim_v */,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
bool kHasSink_ = false /* learnable per-Q-head attention sink */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
// When true, the pipeline reads a per-Q-head sink scalar at init
// time and seeds the online softmax with the corresponding virtual
// key (GPT-OSS / vLLM convention). The kernel forwards the pointer
// via `kargs.sink_ptr`. Default `false` reproduces the classic
// no-sink softmax; no instance flips this yet.
static constexpr bool kHasSink = kHasSink_;
};
} // namespace ck_tile

View File

@@ -53,5 +53,11 @@ struct UnifiedAttentionPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
// Learnable per-Q-head attention sink. False reproduces the classic
// no-sink softmax; true seeds the online softmax denominator with one
// virtual key per Q head (GPT-OSS / vLLM convention). Threaded through
// from `TileUnifiedAttentionTraits::kHasSink` so a single trait knob
// controls both the pipeline init and the kernel-side pointer arithmetic.
static constexpr bool kHasSink = Traits::kHasSink;
};
} // namespace ck_tile