[CK_TILE] Add fmt: skip to FMHA codegen scripts for readability (#3057)

* fmt: skip for fmha_bwd.py

* more fmt: skip

* thank you, copilot

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yi DING
2025-10-21 10:15:04 +08:00
committed by GitHub
parent 2570462ecf
commit e20923f384
6 changed files with 111 additions and 1594 deletions

View File

@@ -575,30 +575,8 @@ class KernelComponentFactory:
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
128: [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
}
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
else:
return None
@@ -618,40 +596,10 @@ class KernelComponentFactory:
["t", "f"],
["t", "f"],
):
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"f",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
)
)
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
else:
assert False
return pipelines
@@ -663,33 +611,7 @@ class CustomFactory(KernelComponentFactory):
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if 128 in result.keys():
result[128].insert(
0,
FmhaFwdTileSize(
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
CppConstraint(
"get_num_blocks(128) < num_cus * min_cu_util_rate"
),
),
)
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
return result

View File

@@ -408,369 +408,29 @@ def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if dtype == "fp32" and tr_load == "f":
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize(
32,
128,
32,
32,
32,
32,
64,
32,
32,
1,
4,
1,
4,
1,
1,
2,
2,
1,
16,
16,
16,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize(
16,
64,
64,
16,
64,
16,
16,
64,
64,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
16,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize(
16,
64,
128,
16,
128,
16,
16,
128,
128,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
16,
16,
16,
16,
1,
),
]
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
] # fmt: skip
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f":
return [
FmhaBwdDQDKDVTileSize(
32,
128,
32,
32,
32,
32,
64,
32,
32,
1,
4,
1,
4,
1,
1,
2,
2,
1,
16,
16,
32,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize(
32,
128,
64,
32,
64,
32,
32,
64,
64,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize(
32,
128,
96,
32,
96,
32,
32,
96,
96,
1,
4,
1,
4,
1,
1,
2,
2,
1,
16,
16,
32,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize(
16,
128,
128,
16,
128,
16,
32,
128,
128,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize(
16,
64,
256,
16,
256,
16,
32,
256,
256,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
16,
1,
),
]
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
] # fmt: skip
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t":
return [
FmhaBwdDQDKDVTileSize(
32,
128,
64,
32,
64,
32,
32,
64,
64,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
32,
1,
),
FmhaBwdDQDKDVTileSize(
32,
128,
128,
32,
128,
32,
32,
128,
128,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
32,
1,
),
FmhaBwdDQDKDVTileSize(
16,
192,
128,
16,
128,
16,
32,
128,
128,
1,
4,
1,
4,
1,
1,
1,
4,
1,
16,
16,
32,
16,
16,
16,
1,
),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
FmhaBwdDQDKDVTileSize(
32,
16,
64,
32,
64,
32,
16,
64,
64,
1,
1,
1,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
16,
2,
32,
),
FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize(
16,
16,
128,
16,
128,
16,
16,
128,
128,
1,
1,
1,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
16,
2,
16,
),
]
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
] # fmt: skip
else:
return []

View File

@@ -635,578 +635,42 @@ class KernelComponentFactory:
if dtype == "fp32":
return {
# bm0, bn0, bk0, bn1, bk1,
(32, 32): [
FmhaFwdTileSize(
64,
64,
16,
32,
32,
32,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
)
],
(48, 48): [
FmhaFwdTileSize(
32,
128,
16,
48,
16,
48,
2,
1,
1,
2,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
FmhaFwdTileSize(
128,
64,
16,
48,
32,
48,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
],
(64, 64): [
FmhaFwdTileSize(
64,
64,
32,
64,
32,
64,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
)
],
(96, 128): [
FmhaFwdTileSize(
128,
64,
32,
128,
32,
96,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
)
],
(128, 128): [
FmhaFwdTileSize(
32,
128,
32,
128,
16,
128,
2,
1,
1,
2,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
FmhaFwdTileSize(
128,
64,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
),
],
(192, 192): [
FmhaFwdTileSize(
64,
64,
32,
192,
32,
192,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
)
],
(256, 256): [
FmhaFwdTileSize(
64,
64,
32,
256,
32,
256,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
)
],
}
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype == "fp16" or dtype == "bf16":
return {
(32, 32): [
FmhaFwdTileSize(
128,
64,
16,
32,
32,
32,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
(64, 64): [
FmhaFwdTileSize(
16,
32,
64,
64,
32,
64,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
32,
-1,
),
FmhaFwdTileSize(
32,
32,
64,
64,
32,
64,
1,
1,
1,
1,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize(
128,
64,
32,
64,
32,
64,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
(96, 128): [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
96,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
(128, 128): [
FmhaFwdTileSize(
16,
32,
64,
128,
32,
128,
1,
1,
1,
1,
1,
1,
16,
16,
32,
16,
16,
32,
-1,
),
FmhaFwdTileSize(
32,
32,
128,
128,
32,
128,
1,
1,
1,
1,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize(
128,
64,
32,
128,
16,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
FmhaFwdTileSize(
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
),
],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192, 128): [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
192,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
(192, 192): [
FmhaFwdTileSize(
128,
128,
32,
192,
32,
192,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
1,
)
],
(256, 256): [
FmhaFwdTileSize(
128,
128,
32,
256,
32,
256,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
}
( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif dtype == "fp8" or dtype == "fp8bf16":
return {
(64, 64): [
FmhaFwdTileSize(
128,
64,
32,
64,
32,
64,
2,
1,
1,
2,
1,
1,
32,
32,
32,
32,
32,
32,
-1,
)
],
(128, 128): [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
32,
32,
32,
32,
-1,
)
],
(256, 256): [
FmhaFwdTileSize(
128,
128,
32,
256,
32,
256,
4,
1,
1,
4,
1,
1,
32,
32,
32,
32,
32,
32,
-1,
)
],
}
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
elif dtype == "fp8fp32":
return {
(128, 128): [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
32,
32,
32,
32,
-1,
)
],
}
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
return None
@@ -1229,60 +693,9 @@ class KernelComponentFactory:
["t", "f"],
["t", "f"],
):
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"f",
"t",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
elif dtype in ["fp16", "bf16"]:
squant = "f"
for logits, mask, bias, lse, dropout, skip in itertools.product(
@@ -1294,137 +707,18 @@ class KernelComponentFactory:
["t", "f"],
):
if hdim == 256 and hdim_v == 256:
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
# the two pipelines below are used for hdim vectorized load
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
else:
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"f",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
and logits == "f"
@@ -1433,103 +727,18 @@ class KernelComponentFactory:
and lse == "f"
and skip == "f"
):
pipelines.append(
FmhaFwdPipeline(
"qr_async_trload",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"t",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async_trload",
"row",
"f",
"f",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"t",
)
)
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
if receipt == 1 and bias != "bias":
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
skip,
"f",
)
) # TODO: cover arbitrary hdim
pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitrary hdim
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
# lse/dropout kernels are not needed
for logits, squant, mask, bias in itertools.product(
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
"f",
"f",
squant,
mask,
"f",
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"f",
"f",
squant,
mask,
"f",
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8fp16", "bf8"]:
# TODO
None
@@ -1544,33 +753,7 @@ class CustomFactory(KernelComponentFactory):
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if (128, 128) in result.keys():
result[(128, 128)].insert(
0,
FmhaFwdTileSize(
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
CppConstraint(
"get_num_blocks(128) < num_cus * min_cu_util_rate"
),
),
)
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
return result

View File

@@ -349,43 +349,17 @@ def get_fwd_appendkv_blobs(
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ["row", "col"]:
for pagedkv in ["t", "f"]:
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "f", "f", "no", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "no", pagedkv
)
)
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "t", "f", "inter", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "inter", pagedkv
)
)
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "t", "f", "half", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "half", pagedkv
)
)
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# rope/paged-kv is not supported
pipelines.append(
FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")
)
pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None

View File

@@ -738,32 +738,18 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
"32": FmhaFwdTileSize(
32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1
),
"64": FmhaFwdTileSize(
64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
),
"96": FmhaFwdTileSize(
64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
),
"128": FmhaFwdTileSize(
64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
),
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize(
64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
),
}
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
elif dtype == "fp8" or dtype == "bf8":
return {
"64": FmhaFwdTileSize(
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
"128": FmhaFwdTileSize(
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
}
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
@@ -807,157 +793,22 @@ def get_fwd_splitkv_blobs(
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
):
pipelines.append(
Pipeline(
"qr",
"row",
"f",
"t",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(
Pipeline(
"qr",
"col",
"f",
"t",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(
Pipeline(
"qr",
"row",
"t",
"f",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(
Pipeline(
"qr",
"col",
"t",
"f",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(
Pipeline(
"qr",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(
Pipeline(
"qr",
"col",
"t",
"t",
"f",
"f",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(
Pipeline(
"qr",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(
Pipeline(
"qr",
"col",
"t",
"t",
"t",
"t",
logits,
bias,
"t",
squant,
pagedkv,
mask,
)
)
pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(
Pipeline(
"qr",
"col",
"f",
"f",
"f",
"f",
logits,
bias,
"t",
squant,
"f",
mask,
)
)
pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None

View File

@@ -524,27 +524,19 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1
),
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
} # fmt: skip
elif dtype == "fp8" or dtype == "bf8":
return {
"64": FmhaFwdTileSize(
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
"128": FmhaFwdTileSize(
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
"256": FmhaFwdTileSize(
128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
}
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
} # fmt: skip
else:
return None
@@ -569,82 +561,17 @@ def get_fwd_blobs(
["t"],
["f"],
):
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"f",
"f",
"f",
logits,
bias,
"f",
pagedkv,
squant,
mask,
skip,
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"f",
pagedkv,
squant,
mask,
skip,
)
)
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# lse/dropout kernels are not needed
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
"f",
"t",
squant,
mask,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"f",
"t",
squant,
mask,
"f",
)
)
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
pass # TODO
else:
assert False
return pipelines