This commit is contained in:
Sami Aario
2026-05-26 15:23:01 +00:00
parent b686143624
commit 8e18d8221a
3 changed files with 33 additions and 17 deletions

View File

@@ -6,6 +6,7 @@
namespace ck_tile {
// d64 GQA-8 tiny+bs32 decode (kBlockM=16, kBlockQ=2, BlockSize=32), masked causal (non-SWA).
using kernel_traits =
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;

View File

@@ -38,7 +38,7 @@ DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens
# Decode + page_blk_size=32 fixtures for d=64 GQA-8. These exercise the NEW
# decode-tier IsLocal=true instances added in Phase 5:
# - DECODE_BS32_Q1 (q=1) → tiny+bs32 local (kBlockM=32, kBlockQ=4)
# - DECODE_BS32_Q1 (q=1) → tiny+bs32 local (kBlockM=16, kBlockQ=2)
# - DECODE_BS32_QM (q in [256,1024]) → medium+bs32 local (kBlockM=128, kBlockQ=16)
# Use the GPT-OSS-shaped window (left=127, right=0) to mirror the production
# workload that motivated Phase 5.

View File

@@ -43,9 +43,22 @@ std::ostream& operator<<(std::ostream& stream,
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
// Retired: DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL is kept below for reference only.
// The definition is fully commented out. No live `#define` line ending in a backslash may
// precede it: line splicing would fold these comment lines into that directive and silently
// define the macro as empty. The comment lines carry no trailing backslashes for the same reason.
// #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV)
// {
//     using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>;
//     return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config);
// }
// Tiny decode with page_blk_size=32: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
// Instantiates unified_attention_decode_tiny_kernel_traits for
// (DType, IsMask, HSize, BM, NQPKV) with the trailing block-size argument fixed at 32,
// then returns the decode dispatcher's result from the enclosing function.
// Requires `args` and `config` to be in scope at the expansion site.
// Exactly one `kernel_traits` alias is declared per expansion, and IsMask is forwarded
// (not hard-coded) so the unmasked and masked call sites can share this macro.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(DType, IsMask, HSize, BM, NQPKV) \
    { \
        using kernel_traits = \
            unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
        return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
    }
// IsLocal=true variant of the tiny decode dispatch with the block-size argument fixed at 32.
// Instantiates unified_attention_decode_tiny_kernel_traits with IsMasking forced to true and
// IsLocal=true, then returns the decode dispatcher's result from the enclosing function.
// There is no IsMask parameter: the visible call sites only reach this macro on the
// `is_mask && is_local` branch.
// Requires `args` and `config` to be in scope at the expansion site.
// NOTE(review): IsLocal presumably selects the sliding-window (SWA) path; confirm against the
// traits definition.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(DType, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
@@ -81,11 +94,11 @@ std::ostream& operator<<(std::ostream& stream,
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
{ \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
// { \
// using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
// return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
// }
enum class tile_tier { large, medium, small, tiny };
@@ -133,7 +146,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
// ---------------+---------------+-----------------------------+---------------------
// d128 MHA | >= 32 | tile_tier::large | d128_*_mask_local
// d64 GQA-8 | >= 64 | tile_tier::large | d64_*_mask_gqa8_local
// d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_narrow_local
// d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_decode_t_local
// d64 GQA-8 | == 32, med | tile_tier::medium | d64_*_mask_gqa8_bs32_decode_local
// (small+bs32 SWA has no instance yet -> Triton fallback; GPT-OSS shows zero
// such shapes in practice. Bumping it to medium is plausible but wastes a
@@ -196,19 +209,19 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
if(tier == tile_tier::tiny)
{
if(use_bs32) {
// bs32 narrow: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4.
// Avoids 1-warp race condition; 2x less waste than small tier.
// Tiny+bs32: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 (matches Triton decode).
// Decode grid requires max_seqlen_q <= kBlockQ (see select_tile_tier).
if(args.data_type == unified_attention_args::data_type_enum::fp16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 16, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 16, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 16, 8)
}
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
{
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 32, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 16, 8)
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 16, 8)
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 16, 8)
}
} else {
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
@@ -297,6 +310,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
return std::make_pair(false, -1.f);
}
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW