From 086512d842532704a3824d3a7305d7c4c190734d Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Wed, 27 May 2026 13:02:39 +0000 Subject: [PATCH] Add IsLocal argument to trait --- .../unified_attention_impl.hpp | 35 ++++++++++++------- .../kernel/unified_attention_kernel.hpp | 12 +++++-- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 7088423cc3..a4088fed91 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -247,7 +247,8 @@ struct variant_config template + ck_tile::index_t kPageSize_ = 0, + bool IsLocal_ = false> struct unified_attention_kernel_traits { using cfg = variant_config; @@ -255,6 +256,7 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; + static constexpr bool is_local = IsLocal_; static constexpr KernelVariant variant = V; static constexpr ck_tile::index_t kPageSize = kPageSize_; @@ -316,7 +318,7 @@ struct unified_attention_kernel_traits using unified_attention_traits = TileUnifiedAttentionTraits; // kBlockPerCu - using unified_attention_mask = GenericAttentionMask; + using unified_attention_mask = GenericAttentionMask; using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem unified_attention_kernel_dispatch(const unified_attention } // namespace ck_tile -// One-line instantiation per (V, DataType, IsMasking, PageSize) combination. -// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0 -// is the legacy runtime-page-size instance (catch-all fallback). Non-zero -// values pin the runtime `page_size` argument to that literal — see the -// dispatch_variant switch in unified_attention.cpp for routing. 
-#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \ +// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal) +// combination. Each instance .cpp consists of exactly one of these calls. +// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all +// fallback). IS_LOCAL_ = false is the non-SWA path (causal / no-mask); +// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the +// left and right window bounds inside the mask (used from Phase 3 on). +#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \ + IS_LOCAL_) \ template <> \ std::pair unified_attention_kernel_dispatch< \ unified_attention_kernel_traits>(const unified_attention_args& args, \ - const stream_config& config) \ + PAGE_SIZE_, \ + IS_LOCAL_>>(const unified_attention_args& args, \ + const stream_config& config) \ { \ using Traits = unified_attention_kernel_traits< \ KernelVariant::VARIANT_, \ unified_attention_args::data_type_enum::DTYPE_, \ IS_MASK_, \ - PAGE_SIZE_>; \ + PAGE_SIZE_, \ + IS_LOCAL_>; \ return std::make_pair(true, \ unified_attention_kernel_launch(args, config)); \ } -// Backward-compat shorthand for the existing one-liners — the default -// `PageSize = 0` instance is the catch-all runtime-page-size kernel. +// Backward-compat wrappers — every existing instance .cpp uses one of these +// and defaults to `IsLocal = false` (the non-SWA path). 
+#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \ + INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false) + #define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \ INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e35d5eb288..2bda5c51fb 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -539,14 +539,20 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) + // Window args default to (-1, -1, false) on the host side, which + // make_generic_attention_mask_from_lr_window is expected to collapse + // to the previous bottom-right causal layout via its `< 0` branches. + // NOTE(review): the old call hard-coded (-1, 0, false); confirm that + // right = -1 behaves like right = 0 here. Once IsLocal=true instances + // are wired up in Phase 3 the same call site honours real SWA bounds. return ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, + kargs.window_size_left, + kargs.window_size_right, cur_batch_query_len, // y_total seq_len, // x_total num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile - false); + kargs.is_top_left); else return FmhaMask{cur_batch_query_len, seq_len}; }();