mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Fix CK-UA mixed batch: use max_seqlen_q for tier selection
Decode grid (num_kv_heads, num_seqs) assumes each seq has <= kBlockQ tokens. For mixed batches (decode + prefill), avg_q is low but some seqs have hundreds of tokens, causing truncation. Added max_seqlen_q to args and check it in select_tile_tier to force medium tier (1D grid with Q tile iteration) for mixed batches. 362/362 no-window shapes now pass. Made-with: Cursor
This commit is contained in:
@@ -69,18 +69,22 @@ enum class tile_tier { large, medium, small, tiny };
|
||||
static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens;
|
||||
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; // kBlockQ for 1-warp 16x16 kernel
|
||||
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv;
|
||||
const index_t kBlockQ_small = 64 / args.num_queries_per_kv;
|
||||
const index_t kBlockQ_medium = 128 / args.num_queries_per_kv;
|
||||
|
||||
if(avg_q <= kBlockQ_tiny)
|
||||
return tile_tier::tiny; // pure decode: 1 warp, 16x16 MFMA, kBlockM=16
|
||||
// Decode tiers use a 2D grid (num_kv_heads, num_seqs) that assumes each
|
||||
// seq has at most kBlockQ tokens. For mixed batches where some seqs have
|
||||
// many more tokens, we must use the medium tier (1D grid with Q tile iteration).
|
||||
const index_t max_q = args.max_seqlen_q > 0 ? args.max_seqlen_q : avg_q;
|
||||
|
||||
const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
|
||||
if(avg_q <= kBlockQ_small)
|
||||
return tile_tier::small; // short decode: 2 warps, kBlockM=64
|
||||
if(avg_q <= kBlockQ_tiny && max_q <= kBlockQ_tiny)
|
||||
return tile_tier::tiny;
|
||||
|
||||
// 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
|
||||
// (verified by exhaustive sweep over 363 shapes from production trace).
|
||||
return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
|
||||
if(avg_q <= kBlockQ_small && max_q <= kBlockQ_small)
|
||||
return tile_tier::small;
|
||||
|
||||
return tile_tier::medium;
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
|
||||
@@ -66,6 +66,7 @@ struct unified_attention_args
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
|
||||
Reference in New Issue
Block a user