Fix CK-UA mixed batch: use max_seqlen_q for tier selection

Decode grid (num_kv_heads, num_seqs) assumes each seq has <= kBlockQ
tokens. For mixed batches (decode + prefill), avg_q is low but some
seqs have hundreds of tokens, causing truncation. Added max_seqlen_q
to args and check it in select_tile_tier to force medium tier (1D
grid with Q tile iteration) for mixed batches.

362/362 no-window shapes now pass.

Made-with: Cursor
This commit is contained in:
root
2026-04-01 18:09:48 +00:00
parent 07ba03bcbf
commit 65a3f88ad8
2 changed files with 14 additions and 9 deletions

View File

@@ -69,18 +69,22 @@ enum class tile_tier { large, medium, small, tiny };
static tile_tier select_tile_tier(const unified_attention_args& args)
{
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens;
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; // kBlockQ for 1-warp 16x16 kernel
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
const index_t kBlockQ_medium = 128 / args.num_queries_per_kv;
if(avg_q <= kBlockQ_tiny)
return tile_tier::tiny; // pure decode: 1 warp, 16x16 MFMA, kBlockM=16
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
// seq has at most kBlockQ tokens. For mixed batches where some seqs have
// many more tokens, we must use the medium tier (1D grid with Q tile iteration).
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
if(avg_q <= kBlockQ_small)
return tile_tier::small; // short decode: 2 warps, kBlockM=64
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny)
return tile_tier::tiny;
// 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
// (verified by exhaustive sweep over 363 shapes from production trace).
return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small)
return tile_tier::small;
return tile_tier::medium;
}
std::pair<bool, float> unified_attention(const unified_attention_args& args,

View File

@@ -66,6 +66,7 @@ struct unified_attention_args
const int32_t* query_start_len_ptr; // [num_seqs+1]
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
};
std::ostream& operator<<(std::ostream& stream,