mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
sparse_attn: add bm0 dispatch for sparge blockmap compatibility
Add bm0 field to fmha_jenga_fwd_traits so callers can specify the preferred Q-tile size. Codegen now emits separate tile configs for bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint guards to select the right kernel at runtime. End-to-end test passes for both jenga and vsa paths. Performance is known to be suboptimal at this stage; tile sizes and warp counts for the bm0=64 path have not been tuned. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -690,12 +690,12 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
@@ -703,13 +703,36 @@ class KernelComponentFactory:
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
|
||||
64,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 64"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
@@ -774,27 +797,6 @@ class KernelComponentFactory:
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
@@ -909,7 +911,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async":
|
||||
continue
|
||||
|
||||
@@ -690,12 +690,12 @@ class KernelComponentFactory:
|
||||
# FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128
|
||||
64,
|
||||
FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test)
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
@@ -703,13 +703,36 @@ class KernelComponentFactory:
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 0 || t.bm0 == 128"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64)
|
||||
64,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint("t.bm0 == 64"),
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
16,
|
||||
@@ -774,27 +797,6 @@ class KernelComponentFactory:
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize( # fmt: skip
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
@@ -909,7 +911,7 @@ def get_fwd_blobs(
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)
|
||||
):
|
||||
if tile.F_bm0 != 64 or tile.F_bn0 != 128:
|
||||
if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128:
|
||||
continue
|
||||
if pipeline.tag != "qr_async_vsa":
|
||||
continue
|
||||
|
||||
@@ -272,7 +272,7 @@ struct fmha_jenga_fwd_traits
|
||||
std::string data_type;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
// TODO: padding check is inside this api
|
||||
int bm0 = 0; // preferred Q-tile size; 0 = don't care (dispatch picks largest)
|
||||
};
|
||||
|
||||
float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -249,6 +249,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_jenga_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
@@ -291,6 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_vsa_fwd_args attn_args;
|
||||
attn_args.q_ptr = q_dev.GetDeviceBuffer();
|
||||
|
||||
Reference in New Issue
Block a user