diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp
index 6cd32bd5da..9cc1231664 100644
--- a/example/ck_tile/42_unified_attention/unified_attention.cpp
+++ b/example/ck_tile/42_unified_attention/unified_attention.cpp
@@ -57,13 +57,11 @@ static tile_tier select_tile_tier(const unified_attention_args& args)
     const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
 
     if(avg_q <= kBlockQ_small)
-        return tile_tier::small; // decode: 2 warps, kBlockM=64
+        return tile_tier::small; // short decode: 2 warps, kBlockM=64
 
-    const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; // kBlockQ for 4-warp kernel
-    if(avg_q <= kBlockQ_medium * 8)
-        return tile_tier::medium; // many short seqs: 4 warps, kBlockM=128
-
-    return tile_tier::large; // long prefill: 8 warps, kBlockM=256
+    // 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
+    // (verified by exhaustive sweep over 363 shapes from production trace).
+    return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
 }
 
 std::pair unified_attention(const unified_attention_args& args,