mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
UA: make sink parameters available to the traits
This commit is contained in:
@@ -248,7 +248,8 @@ template <KernelVariant V,
|
||||
unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
ck_tile::index_t kPageSize_ = 0,
|
||||
bool IsLocal_ = false>
|
||||
bool IsLocal_ = false,
|
||||
bool kHasSink_ = false>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
using cfg = variant_config<V>;
|
||||
@@ -257,6 +258,7 @@ struct unified_attention_kernel_traits
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr bool is_local = IsLocal_;
|
||||
static constexpr bool has_sink = kHasSink_;
|
||||
static constexpr KernelVariant variant = V;
|
||||
static constexpr ck_tile::index_t kPageSize = kPageSize_;
|
||||
|
||||
@@ -315,9 +317,10 @@ struct unified_attention_kernel_traits
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>; // IsVLayoutRowMajor
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
-1>; // kBlockPerCu
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
-1, // kBlockPerCu
|
||||
kHasSink_>; // kHasSink
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, IsLocal_>;
|
||||
|
||||
using unified_attention_pipeline_problem =
|
||||
@@ -441,21 +444,27 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal)
|
||||
// combination. Each instance .cpp consists of exactly one of these calls.
|
||||
// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all
|
||||
// fallback). IS_LOCAL_ = false is the non-SWA path (causal / no-mask);
|
||||
// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the
|
||||
// left and right window bounds inside the mask.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
|
||||
IS_LOCAL_) \
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal,
|
||||
// kHasSink) combination. Each instance .cpp consists of exactly one of
|
||||
// these calls. PAGE_SIZE_ = 0 is the legacy runtime-page-size instance
|
||||
// (catch-all fallback). IS_LOCAL_ = false is the non-SWA path (causal /
|
||||
// no-mask); IS_LOCAL_ = true compiles the SWA-capable kernel that honours
|
||||
// both the left and right window bounds inside the mask. HAS_SINK_ = true
|
||||
// compiles the sink-aware kernel which seeds the online softmax with a
|
||||
// per-Q-head virtual key (GPT-OSS / vLLM convention); HAS_SINK_ = false is
|
||||
// the classic no-sink softmax. No instance file flips HAS_SINK_ yet — the
|
||||
// trait knob exists so the pipeline can introduce an `if constexpr
|
||||
// (kHasSink)` branch without changing codegen on the existing instances.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_,\
|
||||
IS_LOCAL_, HAS_SINK_) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch< \
|
||||
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_, \
|
||||
IS_LOCAL_>>(const unified_attention_args& args, \
|
||||
IS_LOCAL_, \
|
||||
HAS_SINK_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
{ \
|
||||
using Traits = unified_attention_kernel_traits< \
|
||||
@@ -463,14 +472,22 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_, \
|
||||
IS_LOCAL_>; \
|
||||
IS_LOCAL_, \
|
||||
HAS_SINK_>; \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<typename Traits::kernel, \
|
||||
Traits::kUseDecodeGrid>(args, config)); \
|
||||
}
|
||||
|
||||
// Backward-compat wrappers — every existing instance .cpp uses one of these
|
||||
// and defaults to `IsLocal = false` (the non-SWA path).
|
||||
// and defaults to `IsLocal = false` (the non-SWA path) and `kHasSink = false`
|
||||
// (the classic softmax). Each wrapper forwards through the canonical 6-arg
|
||||
// _PS_LOCAL_SINK form, so adding a sink instance only requires the new macro.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
|
||||
IS_LOCAL_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL_SINK( \
|
||||
VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, IS_LOCAL_, false)
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false)
|
||||
|
||||
|
||||
@@ -9,11 +9,18 @@ namespace ck_tile {
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDim_ /* padding for hdim_v */,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
|
||||
bool kHasSink_ = false /* learnable per-Q-head attention sink */>
|
||||
struct TileUnifiedAttentionTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDim = kPadHeadDim_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
// When true, the pipeline reads a per-Q-head sink scalar at init
|
||||
// time and seeds the online softmax with the corresponding virtual
|
||||
// key (GPT-OSS / vLLM convention). The kernel forwards the pointer
|
||||
// via `kargs.sink_ptr`. Default `false` reproduces the classic
|
||||
// no-sink softmax; no instance flips this yet.
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -53,5 +53,11 @@ struct UnifiedAttentionPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
// Learnable per-Q-head attention sink. False reproduces the classic
|
||||
// no-sink softmax; true seeds the online softmax denominator with one
|
||||
// virtual key per Q head (GPT-OSS / vLLM convention). Threaded through
|
||||
// from `TileUnifiedAttentionTraits::kHasSink` so a single trait knob
|
||||
// controls both the pipeline init and the kernel-side pointer arithmetic.
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user