From ea157f6244f27d4bf438f04c08ad1d1b4e5e181f Mon Sep 17 00:00:00 2001 From: Amir Ghamarian Date: Sat, 28 Mar 2026 13:52:42 +0000 Subject: [PATCH] Route all prefill to 4-warp kBlockM=128 kernel Exhaustive sweep over 363 production trace shapes shows the 4-warp serial pipeline outperforms the 8-warp interleaved pipeline on every single prefill shape (0 exceptions out of 71 prefill shapes). The 4-warp kernel has better CU occupancy and the serial pipeline's async prefetch is sufficient for these workloads. Dispatch now: tiny (decode) -> small (short decode) -> medium (all prefill). The 8-warp large tier is no longer used for d64 GQA-8. Made-with: Cursor --- .../ck_tile/42_unified_attention/unified_attention.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index 6cd32bd5da..9cc1231664 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -57,13 +57,11 @@ static tile_tier select_tile_tier(const unified_attention_args& args) const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel if(avg_q <= kBlockQ_small) - return tile_tier::small; // decode: 2 warps, kBlockM=64 + return tile_tier::small; // short decode: 2 warps, kBlockM=64 - const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; // kBlockQ for 4-warp kernel - if(avg_q <= kBlockQ_medium * 8) - return tile_tier::medium; // many short seqs: 4 warps, kBlockM=128 - - return tile_tier::large; // long prefill: 8 warps, kBlockM=256 + // 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes + // (verified by exhaustive sweep over 363 shapes from production trace). + return tile_tier::medium; // all prefill: 4 warps, kBlockM=128 } std::pair unified_attention(const unified_attention_args& args,