sparse_attn: add bm0 dispatch for sparge blockmap compatibility

Add bm0 field to fmha_jenga_fwd_traits so callers can specify the
preferred Q-tile size. Codegen now emits separate tile configs for
bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint
guards to select the right kernel at runtime.

End-to-end test passes for both jenga and vsa paths. Performance is
known to be suboptimal at this stage; tile sizes and warp counts for
the bm0=64 path have not been tuned.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-04-24 05:13:51 -04:00
parent ab44b83566
commit eca3cb3e0a
4 changed files with 67 additions and 61 deletions

View File

@@ -690,12 +690,12 @@ class KernelComponentFactory:
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
64,
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
128,
64,
128,
64,
32,
128,
32,
128,
4,
1,
@@ -703,13 +703,36 @@ class KernelComponentFactory:
4,
1,
1,
32,
32,
16,
16,
16,
16,
16,
32,
32,
16,
-1,
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
),
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
64,
128,
32,
128,
32,
128,
2,
1,
1,
2,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
CppConstraint("t.bm0 == 64"),
),
FmhaFwdTileSize( # fmt: skip
16,
@@ -774,27 +797,6 @@ class KernelComponentFactory:
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
@@ -909,7 +911,7 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async":
continue

View File

@@ -690,12 +690,12 @@ class KernelComponentFactory:
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128): [
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
64,
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
128,
64,
128,
64,
32,
128,
32,
128,
4,
1,
@@ -703,13 +703,36 @@ class KernelComponentFactory:
4,
1,
1,
32,
32,
16,
16,
16,
16,
16,
32,
32,
16,
-1,
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
),
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
64,
128,
32,
128,
32,
128,
2,
1,
1,
2,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
CppConstraint("t.bm0 == 64"),
),
FmhaFwdTileSize( # fmt: skip
16,
@@ -774,27 +797,6 @@ class KernelComponentFactory:
16,
-1,
),
FmhaFwdTileSize( # fmt: skip
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
@@ -909,7 +911,7 @@ def get_fwd_blobs(
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
):
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
continue
if pipeline.tag != "qr_async_vsa":
continue

View File

@@ -272,7 +272,7 @@ struct fmha_jenga_fwd_traits
std::string data_type;
bool is_v_rowmajor;
mask_enum mask_type;
// TODO: padding check is inside this api
int bm0 = 0; // preferred Q-tile size; 0 = don't care (dispatch picks largest)
};
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);

View File

@@ -249,6 +249,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
attn_traits.is_v_rowmajor = true;
attn_traits.mask_type = mask_enum::no_mask;
attn_traits.bm0 = BLKQ;
fmha_jenga_fwd_args attn_args;
attn_args.q_ptr = q_dev.GetDeviceBuffer();
@@ -291,6 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
attn_traits.is_v_rowmajor = true;
attn_traits.mask_type = mask_enum::no_mask;
attn_traits.bm0 = BLKQ;
fmha_vsa_fwd_args attn_args;
attn_args.q_ptr = q_dev.GetDeviceBuffer();