Route all prefill to 4-warp kBlockM=128 kernel

Exhaustive sweep over 363 production trace shapes shows the 4-warp
serial pipeline outperforms the 8-warp interleaved pipeline on every
single prefill shape (0 exceptions out of 71 prefill shapes).

The 4-warp kernel has better CU occupancy and the serial pipeline's
async prefetch is sufficient for these workloads.

Dispatch now: tiny (decode) -> small (short decode) -> medium (all prefill).
The 8-warp large tier is no longer used for d64 GQA-8.

Made-with: Cursor
This commit is contained in:
Amir Ghamarian
2026-03-28 13:52:42 +00:00
parent 33b2015939
commit ea157f6244

View File

@@ -57,13 +57,11 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
if(avg_q <= kBlockQ_small)
return tile_tier::small; // decode: 2 warps, kBlockM=64
return tile_tier::small; // short decode: 2 warps, kBlockM=64
const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; // kBlockQ for 4-warp kernel
if(avg_q <= kBlockQ_medium * 8)
return tile_tier::medium; // many short seqs: 4 warps, kBlockM=128
return tile_tier::large; // long prefill: 8 warps, kBlockM=256
// 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
// (verified by exhaustive sweep over 363 shapes from production trace).
return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
}
std::pair<bool, float> unified_attention(const unified_attention_args& args,