Add SWA decode dispatchers to support the GPT-OSS shape + update smoke test

This commit is contained in:
Damien Lejeune
2026-05-08 14:38:16 +00:00
parent e36693c4dc
commit b686143624
6 changed files with 188 additions and 27 deletions

View File

@@ -34,6 +34,21 @@ std::ostream& operator<<(std::ostream& stream,
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
// SWA-aware decode dispatchers for bs32. These mirror the non-local *_BS32 macros
// but set IsLocal=true (the 7th template argument), so the kernel applies the
// sliding-window mask AND clips the per-Q-tile KV-block iteration range.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(DType, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// Dispatch macros for three tile tiers (default block_size).
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
{ \
@@ -108,22 +123,52 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
const bool is_local =
is_mask && (args.window_size_left >= 0 || args.window_size_right > 0);
auto tier = select_tile_tier(args);
// SWA instances currently only exist at the large prefill tier (kBlockM=256,
// 8 warps). Each requires args.page_blk_size >= kBlockN of the instance —
// otherwise the kernel hits a device-side `kv_page_size_in_blocks >= 1`
// assertion. When SWA is requested on an unsupported (shape, page_blk_size)
// pair we return {false, 0} so the caller (e.g. _try_ck_unified_attention)
// can fall back to a backend that handles it (Triton). Falling through to
// the IsLocal=false path would silently ignore window_size_left and produce
// SWA-instance availability matrix (IsLocal=true). Anything not listed here
// returns {false, 0} so the caller (e.g. _try_ck_unified_attention) falls
// back to a backend that handles it (Triton). Falling through to the
// IsLocal=false path would silently ignore window_size_left and produce
// wrong outputs, so we reject explicitly.
//
// shape | page_blk_size | tier we route to | instance
// ---------------+---------------+-----------------------------+---------------------
// d128 MHA | >= 32 | tile_tier::large | d128_*_mask_local
// d64 GQA-8 | >= 64 | tile_tier::large | d64_*_mask_gqa8_local
// d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_narrow_local
// d64 GQA-8 | == 32, med | tile_tier::medium | d64_*_mask_gqa8_bs32_decode_local
// (small+bs32 SWA has no instance yet -> Triton fallback; GPT-OSS shows zero
// such shapes in practice. Bumping it to medium is plausible but wastes a
// full kBlockM=128 tile when only kBlockM=64 is needed -- revisit if the
// workload changes.)
if(is_local)
{
const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1);
const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8);
const index_t kBN_req = d128_mha ? 32 : (d64_gqa8 ? 64 : 0);
if(kBN_req == 0 || args.page_blk_size < kBN_req)
const bool d128_mha = (args.hdim == 128 && args.num_queries_per_kv == 1);
const bool d64_gqa8 = (args.hdim == 64 && args.num_queries_per_kv == 8);
if(d128_mha)
{
if(args.page_blk_size < 32) return {false, 0.f};
tier = tile_tier::large;
}
else if(d64_gqa8)
{
if(args.page_blk_size >= 64)
{
tier = tile_tier::large;
}
else if(args.page_blk_size >= 32)
{
if(tier != tile_tier::tiny && tier != tile_tier::medium)
return {false, 0.f};
// keep selected tier as-is
}
else
{
return {false, 0.f};
}
}
else
{
return {false, 0.f};
tier = tile_tier::large;
}
}
// d128, MHA (num_queries_per_kv == 1)
@@ -155,13 +200,15 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
// Avoids 1-warp race condition; 2x less waste than small tier.
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
}
} else {
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
@@ -205,8 +252,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(use_bs32) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
@@ -215,8 +263,9 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(use_bs32) {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
} else {
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
@@ -248,6 +297,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
return std::make_pair(false, -1.f);
}
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32