mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Route all prefill to 4-warp kBlockM=128 kernel
An exhaustive sweep over 363 production trace shapes shows that the 4-warp serial pipeline outperforms the 8-warp interleaved pipeline on every prefill shape (0 exceptions out of 71 prefill shapes). The 4-warp kernel has better CU occupancy, and the serial pipeline's async prefetch is sufficient for these workloads. Dispatch is now: tiny (decode) -> small (short decode) -> medium (all prefill). The 8-warp large tier is no longer used for d64 GQA-8. Made-with: Cursor
This commit is contained in:
@@ -57,13 +57,11 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
|
||||
const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
|
||||
if(avg_q <= kBlockQ_small)
|
||||
return tile_tier::small; // decode: 2 warps, kBlockM=64
|
||||
return tile_tier::small; // short decode: 2 warps, kBlockM=64
|
||||
|
||||
const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; // kBlockQ for 4-warp kernel
|
||||
if(avg_q <= kBlockQ_medium * 8)
|
||||
return tile_tier::medium; // many short seqs: 4 warps, kBlockM=128
|
||||
|
||||
return tile_tier::large; // long prefill: 8 warps, kBlockM=256
|
||||
// 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
|
||||
// (verified by exhaustive sweep over 363 shapes from production trace).
|
||||
return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
|
||||
Reference in New Issue
Block a user