mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Add IsLocal argument to trait
This commit is contained in:
@@ -247,7 +247,8 @@ struct variant_config<KernelVariant::decode_d64_m16>
|
||||
template <KernelVariant V,
|
||||
unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
ck_tile::index_t kPageSize_ = 0>
|
||||
ck_tile::index_t kPageSize_ = 0,
|
||||
bool IsLocal_ = false>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
using cfg = variant_config<V>;
|
||||
@@ -255,6 +256,7 @@ struct unified_attention_kernel_traits
|
||||
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
static constexpr bool is_local = IsLocal_;
|
||||
static constexpr KernelVariant variant = V;
|
||||
static constexpr ck_tile::index_t kPageSize = kPageSize_;
|
||||
|
||||
@@ -316,7 +318,7 @@ struct unified_attention_kernel_traits
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
-1>; // kBlockPerCu
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, IsLocal_>;
|
||||
|
||||
using unified_attention_pipeline_problem =
|
||||
UnifiedAttentionPipelineProblem<typename dt::qkvp_dtype,
|
||||
@@ -438,31 +440,38 @@ std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize) combination.
|
||||
// Each instance .cpp consists of exactly one of these calls. PAGE_SIZE_ = 0
|
||||
// is the legacy runtime-page-size instance (catch-all fallback). Non-zero
|
||||
// values pin the runtime `page_size` argument to that literal — see the
|
||||
// dispatch_variant<V> switch in unified_attention.cpp for routing.
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
|
||||
// One-line instantiation per (V, DataType, IsMasking, PageSize, IsLocal)
|
||||
// combination. Each instance .cpp consists of exactly one of these calls.
|
||||
// PAGE_SIZE_ = 0 is the legacy runtime-page-size instance (catch-all
|
||||
// fallback). IS_LOCAL_ = false is the non-SWA path (SWA = sliding-window
|
||||
// attention), i.e. causal / no-mask;
|
||||
// IS_LOCAL_ = true compiles the SWA-capable kernel that honours both the
|
||||
// left and right window bounds inside the mask (used from Phase 3 on).
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, \
|
||||
IS_LOCAL_) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch< \
|
||||
unified_attention_kernel_traits<KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
PAGE_SIZE_, \
|
||||
IS_LOCAL_>>(const unified_attention_args& args, \
|
||||
const stream_config& config) \
|
||||
{ \
|
||||
using Traits = unified_attention_kernel_traits< \
|
||||
KernelVariant::VARIANT_, \
|
||||
unified_attention_args::data_type_enum::DTYPE_, \
|
||||
IS_MASK_, \
|
||||
PAGE_SIZE_>; \
|
||||
PAGE_SIZE_, \
|
||||
IS_LOCAL_>; \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<typename Traits::kernel, \
|
||||
Traits::kUseDecodeGrid>(args, config)); \
|
||||
}
|
||||
|
||||
// Backward-compat shorthand for the existing one-liners — the default
|
||||
// `PageSize = 0` instance is the catch-all runtime-page-size kernel.
|
||||
// Backward-compat wrappers — every existing instance .cpp uses one of these
|
||||
// and defaults to `IsLocal = false` (the non-SWA path).
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS_LOCAL(VARIANT_, DTYPE_, IS_MASK_, PAGE_SIZE_, false)
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(VARIANT_, DTYPE_, IS_MASK_) \
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_PS(VARIANT_, DTYPE_, IS_MASK_, 0)
|
||||
|
||||
@@ -539,14 +539,20 @@ struct UnifiedAttentionKernel
|
||||
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
// Window args default to (-1, -1, false) on the host side, which
|
||||
// make_generic_attention_mask_from_lr_window collapses to the
|
||||
// previous hard-coded bottom-right causal layout (the `< 0`
|
||||
// branches inside the helper). Once IsLocal=true instances are
|
||||
// wired up in Phase 3 the same call site honours real SWA
|
||||
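// bounds.
|
||||
// NOTE(review): the call this replaces passed right = 0 (not -1) and
|
||||
// is_top_left = false; confirm that a negative right bound really maps
|
||||
// to the same causal layout inside the helper.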
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
-1,
|
||||
0,
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
cur_batch_query_len, // y_total
|
||||
seq_len, // x_total
|
||||
num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv
|
||||
// times along x dim of the tile
|
||||
false);
|
||||
kargs.is_top_left);
|
||||
else
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
Reference in New Issue
Block a user