Add split-KV decode tiles (b16x32, b32x32) + fix num_splits heuristic

Add decode tiles for split-KV at hdim=64: bm0=16 with 1 warp and bm0=32
with 2 warps (both with bn0=32).
Fix the num_splits heuristic to use num_heads_kv (not num_heads_q) and to
target 4x the SM count (previously 2x).

Performance is unchanged (0.056 ms) because:
1. Split + combine overhead dominates for short KV (31 pages).
2. Triton 3D's single-kernel split avoids the combine kernel entirely.

Made-with: Cursor
This commit is contained in:
root
2026-04-01 18:49:16 +00:00
parent c5600bc8ae
commit 63821af1ff
2 changed files with 6 additions and 4 deletions

View File

@@ -821,7 +821,9 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
if dtype in ["fp16", "bf16"]:
return {
"32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : [FmhaFwdTileSize( 16, 32, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 32, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],

View File

@@ -154,15 +154,15 @@ int override_num_splits_if_necessary(
return num_splits;
}
// tile size should match the generate.py
const int kM0 = 64;
const int kM0 = 16; // smallest decode tile — use minimum for most parallelism
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
if(num_splits < 1 && p_drop == 0.0f)
{
// Target 4x SMs for full GPU utilization (matching Triton 3D strategy)
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
batch * nhead * num_m_blocks, props.multiProcessorCount * 4, 128);
}
return num_splits;