Add IsLocal argument to trait

This commit is contained in:
Damien Lejeune
2026-05-27 13:02:39 +00:00
parent cea9adab59
commit 086512d842
2 changed files with 31 additions and 16 deletions

View File

@@ -247,7 +247,8 @@ struct variant_config<KernelVariant::decode_d64_m16>
template <KernelVariant V,
unified_attention_args::data_type_enum DataType,
bool IsMasking,
ck_tile::index_t kPageSize_ = 0>
ck_tile::index_t kPageSize_ = 0,
bool IsLocal_ = false>
struct unified_attention_kernel_traits
{
using cfg = variant_config<V>;
@@ -255,6 +256,7 @@ struct unified_attention_kernel_traits
static constexpr auto date_type = DataType;
static constexpr bool is_masking = IsMasking;
static constexpr bool is_local = IsLocal_;
static constexpr KernelVariant variant = V;
static constexpr ck_tile::index_t kPageSize = kPageSize_;
@@ -316,7 +318,7 @@ struct unified_attention_kernel_traits
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
-1>; // kBlockPerCu
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
using unified_attention_mask = GenericAttentionMask<IsMasking, IsLocal_>;
using unified_attention_pipeline_problem =
UnifiedAttentionPipelineProblem<typename dt::qkvp_dtype,
@@ -438,31 +440,38 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
} // namespace ck_tile
// One-line instantiation per (V, DataType, IsMasking, PageSize) combination.
// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0
// is the legacy runtime-page-size instance (catch-all fallback). Non-zero
// values pin the runtime `page_size` argument to that literal — see the
// dispatch_variant<V> switch in unified_attention.cpp for routing.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal)
// combination. Each instance .cpp consists of exactly one of these calls.
// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all
// fallback). IS_LOCAL_ = false is the non-SWA path (causal / no-mask);
// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the
// left and right window bounds inside the mask (used from Phase 3 on).
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
IS_LOCAL_) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch< \
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_>>(const unified_attention_args& args, \
const stream_config& config) \
PAGE_SIZE_, \
IS_LOCAL_>>(const unified_attention_args& args, \
const stream_config& config) \
{ \
using Traits = unified_attention_kernel_traits< \
KernelVariant::VARIANT_, \
unified_attention_args::data_type_enum::DTYPE_, \
IS_MASK_, \
PAGE_SIZE_>; \
PAGE_SIZE_, \
IS_LOCAL_>; \
return std::make_pair(true, \
unified_attention_kernel_launch<typename Traits::kernel, \
Traits::kUseDecodeGrid>(args, config)); \
}
// Backward-compat shorthand for the existing one-liners — the default
// `PageSize = 0` instance is the catch-all runtime-page-size kernel.
// Backward-compat wrappers — every existing instance .cpp uses one of these
// and defaults to `IsLocal = false` (the non-SWA path).
// Instantiates the dispatch specialization for a given
// (VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) with IS_LOCAL_ fixed to `false`
// by forwarding to INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL. Keeps the
// four-argument call signature unchanged for existing call sites.
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false)
// Three-argument shorthand: forwards to INST_UNIFIED_ATTENTION_DISPATCH_PS
// with PAGE_SIZE_ = 0, i.e. the runtime-page-size (catch-all fallback)
// instance, which in turn leaves IS_LOCAL_ at `false`.
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0)

View File

@@ -539,14 +539,20 @@ struct UnifiedAttentionKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
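// Window args default to (-1, -1, false) on the host side, which
// make_generic_attention_mask_from_lr_window is expected to collapse to
// the previous hard-coded bottom-right causal layout (the `< 0`
// branches inside the helper). Once IsLocal=true instances are
// wired up in Phase 3 the same call site honours real SWA
// bounds.
// NOTE(review): the call this replaces passed a right bound of 0, not
// -1 — confirm that the helper's `< 0` branch for the right bound still
// yields the causal layout rather than an unbounded right window.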
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
-1,
0,
kargs.window_size_left,
kargs.window_size_right,
cur_batch_query_len, // y_total
seq_len, // x_total
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
// times along x dim of the tile
false);
kargs.is_top_left);
else
return FmhaMask{cur_batch_query_len, seq_len};
}();