From 63821af1ff6600ef9dbf422645e5dedd6800ab13 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 Apr 2026 18:49:16 +0000 Subject: [PATCH] Add split-KV decode tiles (b16x32, b32x32) + fix num_splits heuristic Decode tiles for split-KV hdim=64: bm0=16/1-warp and bm0=32/2-warp. Fix num_splits to use num_heads_kv (not num_heads_q) and target 4x SMs. Performance unchanged (0.056ms) because: 1. Split+combine overhead dominates for short KV (31 pages) 2. Triton 3D's single-kernel split avoids combine kernel entirely Made-with: Cursor --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +++- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 8b580ed921..627bd27118 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -821,7 +821,9 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): if dtype in ["fp16", "bf16"]: return { "32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - "64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "64" : [FmhaFwdTileSize( 16, 32, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 32, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], "96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], "128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 40b8006381..81b3384962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -154,15 +154,15 @@ int override_num_splits_if_necessary( return num_splits; } - // tile size should match the generate.py - const int kM0 = 64; + const int kM0 = 16; // smallest decode tile — use minimum for most parallelism const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); if(num_splits < 1 && p_drop == 0.0f) { + // Target 4x SMs for full GPU utilization (matching Triton 3D strategy) return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 4, 128); } return num_splits;