mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Add split-KV decode tiles (b16x32, b32x32) + fix num_splits heuristic
Decode tiles for split-KV hdim=64: bm0=16/1-warp and bm0=32/2-warp. Fix num_splits to use num_heads_kv (not num_heads_q) and target 4x SMs. Performance unchanged (0.056ms) because: 1. Split+combine overhead dominates for short KV (31 pages) 2. Triton 3D's single-kernel split avoids combine kernel entirely Made-with: Cursor
This commit is contained in:
@@ -821,7 +821,9 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
"32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : [FmhaFwdTileSize( 16, 32, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 32, 32, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
"128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
|
||||
@@ -154,15 +154,15 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kM0 = 16; // smallest decode tile — use minimum for most parallelism
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
// Target 4x SMs for full GPU utilization (matching Triton 3D strategy)
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 4, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
|
||||
Reference in New Issue
Block a user