mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
sparse_attn: add bm0 dispatch for sparge blockmap compatibility
Add a bm0 field to fmha_jenga_fwd_traits so callers can specify the preferred Q-tile size. Codegen now emits separate tile configs for bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint guards to select the right kernel at runtime.

End-to-end test passes for both jenga and vsa paths. Performance is known to be suboptimal at this stage; tile sizes and warp counts for the bm0=64 path have not been tuned.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -690,12 +690,12 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
@@ -703,13 +703,36 @@ class KernelComponentFactory:
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
|
||||
64,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 64"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
@@ -774,27 +797,6 @@ class KernelComponentFactory:
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
@@ -909,7 +911,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async":
|
||||
continue
|
||||
|
||||
@@ -690,12 +690,12 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
@@ -703,13 +703,36 @@ class KernelComponentFactory:
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
|
||||
64,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 64"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
@@ -774,27 +797,6 @@ class KernelComponentFactory:
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
@@ -909,7 +911,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user