From eca3cb3e0abdcb927a02f9b7b9ed786b9a9cdda2 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Fri, 24 Apr 2026 05:13:51 -0400 Subject: [PATCH] sparse_attn: add bm0 dispatch for sparge blockmap compatibility Add bm0 field to fmha_jenga_fwd_traits so callers can specify the preferred Q-tile size. Codegen now emits separate tile configs for bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint guards to select the right kernel at runtime. End-to-end test passes for both jenga and vsa paths. Performance is known to be suboptimal at this stage; tile sizes and warp counts for the bm0=64 path have not been tuned. Co-Authored-By: Claude Opus 4.7 --- .../codegen/ops/fmha_fwd_jenga.py | 62 ++++++++++--------- .../codegen/ops/fmha_fwd_vsa.py | 62 ++++++++++--------- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 2 +- .../ck_tile/50_sparse_attn/test_sparge.cpp | 2 + 4 files changed, 67 insertions(+), 61 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 1f0a78048d..fc4b8642dd 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -690,12 +690,12 @@ class KernelComponentFactory: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 - 64, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) 128, - 64, 128, - 64, + 32, + 128, + 32, 128, 4, 1, @@ -703,13 +703,36 @@ class KernelComponentFactory: 4, 1, 1, + 32, + 32, 16, - 16, - 16, - 16, - 16, + 32, + 32, 16, -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip 16, @@ -774,27 +797,6 @@ class KernelComponentFactory: 16, -1, ), - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], @@ -909,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 64 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 217cfcfe2a..208877037f 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -690,12 +690,12 @@ class KernelComponentFactory: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 - 64, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) 128, - 64, 128, - 64, + 32, + 128, + 32, 128, 4, 1, @@ -703,13 +703,36 @@ class KernelComponentFactory: 4, 1, 1, + 32, + 32, 16, - 16, - 16, - 16, - 16, + 32, + 32, 16, -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip 16, @@ -774,27 +797,6 @@ class KernelComponentFactory: 16, -1, ), - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 16, - 32, - 32, - 16, - -1, - ), ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], @@ -909,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 64 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 350d1803f6..62d40ffbe0 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -272,7 +272,7 @@ struct fmha_jenga_fwd_traits std::string data_type; bool is_v_rowmajor; mask_enum mask_type; - // TODO: padding check is inside this api + int bm0 = 0; // preferred Q-tile size; 0 = don't care (dispatch picks largest) }; float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index 7c30a10b06..81a49ca006 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -249,6 +249,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; attn_traits.is_v_rowmajor = true; attn_traits.mask_type = mask_enum::no_mask; + attn_traits.bm0 = BLKQ; fmha_jenga_fwd_args attn_args; attn_args.q_ptr = q_dev.GetDeviceBuffer(); @@ -291,6 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; attn_traits.is_v_rowmajor = true; attn_traits.mask_type = mask_enum::no_mask; + attn_traits.bm0 = BLKQ; fmha_vsa_fwd_args attn_args; attn_args.q_ptr = q_dev.GetDeviceBuffer();