diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index aa54396740..bdeb56aed9 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -69,18 +69,22 @@ enum class tile_tier { large, medium, small, tiny }; static tile_tier select_tile_tier(const unified_attention_args& args) { const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens; - const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; // kBlockQ for 1-warp 16x16 kernel + const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; + const index_t kBlockQ_small = 64 / args.num_queries_per_kv; + const index_t kBlockQ_medium = 128 / args.num_queries_per_kv; - if(avg_q <= kBlockQ_tiny) - return tile_tier::tiny; // pure decode: 1 warp, 16x16 MFMA, kBlockM=16 + // Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each + // seq has at most kBlockQ tokens. For mixed batches where some seqs have + // many more tokens, we must use the medium tier (1D grid with Q tile iteration). + const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q; - const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel - if(avg_q <= kBlockQ_small) - return tile_tier::small; // short decode: 2 warps, kBlockM=64 + if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny) + return tile_tier::tiny; - // 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes - // (verified by exhaustive sweep over 363 shapes from production trace). - return tile_tier::medium; // all prefill: 4 warps, kBlockM=128 + if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small) + return tile_tier::small; + + return tile_tier::medium; } std::pair unified_attention(const unified_attention_args& args, diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 64f340c556..8b645387a4 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -66,6 +66,7 @@ struct unified_attention_args const int32_t* query_start_len_ptr; // [num_seqs+1] index_t num_seqs; // number of batches for q + index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown) }; std::ostream& operator<<(std::ostream& stream,