mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-27 16:34:26 +00:00
WIP
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// d64 GQA-8 tiny+bs32 decode (kBlockM=16, kBlockQ=2, BlockSize=32), masked causal (non-SWA).
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ DECODE_B="-d=64 -h_k=1 -nqpkv=8 -b=4 -s=1 -s_k=512 -query_lens=1,1,1,1 -kv_lens
|
||||
|
||||
# Decode + page_blk_size=32 fixtures for d=64 GQA-8. These exercise the NEW
|
||||
# decode-tier IsLocal=true instances added in Phase 5:
|
||||
# - DECODE_BS32_Q1 (q=1) → tiny+bs32 local (kBlockM=32, kBlockQ=4)
|
||||
# - DECODE_BS32_Q1 (q=1) → tiny+bs32 local (kBlockM=16, kBlockQ=2)
|
||||
# - DECODE_BS32_QM (q in [256,1024]) → medium+bs32 local (kBlockM=128, kBlockQ=16)
|
||||
# Use the GPT-OSS-shaped window (left=127, right=0) to mirror the production
|
||||
# workload that motivated Phase 5.
|
||||
|
||||
@@ -43,9 +43,22 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV) \
|
||||
// #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(DType, HSize, BM, NQPKV) \
|
||||
// { \
|
||||
// using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
|
||||
// return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
// }
|
||||
|
||||
// Tiny decode with page_blk_size=32: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
|
||||
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(DType, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, /*IsMasking=*/true, HSize, BM, NQPKV, 32, /*IsLocal=*/true>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
@@ -81,11 +94,11 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
// #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
|
||||
// { \
|
||||
// using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
// return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
// }
|
||||
|
||||
enum class tile_tier { large, medium, small, tiny };
|
||||
|
||||
@@ -133,7 +146,7 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
// ---------------+---------------+-----------------------------+---------------------
|
||||
// d128 MHA | >= 32 | tile_tier::large | d128_*_mask_local
|
||||
// d64 GQA-8 | >= 64 | tile_tier::large | d64_*_mask_gqa8_local
|
||||
// d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_narrow_local
|
||||
// d64 GQA-8 | == 32, tiny | tile_tier::tiny | d64_*_mask_gqa8_bs32_decode_t_local
|
||||
// d64 GQA-8 | == 32, med | tile_tier::medium | d64_*_mask_gqa8_bs32_decode_local
|
||||
// (small+bs32 SWA has no instance yet -> Triton fallback; GPT-OSS shows zero
|
||||
// such shapes in practice. Bumping it to medium is plausible but wastes a
|
||||
@@ -196,19 +209,19 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
if(tier == tile_tier::tiny)
|
||||
{
|
||||
if(use_bs32) {
|
||||
// bs32 narrow: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4.
|
||||
// Avoids 1-warp race condition; 2x less waste than small tier.
|
||||
// Tiny+bs32: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 (matches Triton decode).
|
||||
// Decode grid requires max_seqlen_q <= kBlockQ (see select_tile_tier).
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
|
||||
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 16, 8)
|
||||
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(unified_attention_args::data_type_enum::fp16, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 16, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
|
||||
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 16, 8)
|
||||
else if(is_local) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL(unified_attention_args::data_type_enum::bf16, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 16, 8)
|
||||
}
|
||||
} else {
|
||||
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
|
||||
@@ -297,6 +310,8 @@ std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32_LOCAL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW_LOCAL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32_LOCAL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
|
||||
|
||||
Reference in New Issue
Block a user