mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add fmt: skip to FMHA codegen scripts for readability (#3057)
* fmt: skip for fmha_bwd.py * more fmt: skip * thank you, copilot * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -575,30 +575,8 @@ class KernelComponentFactory:
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
128: [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -618,40 +596,10 @@ class KernelComponentFactory:
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -663,33 +611,7 @@ class CustomFactory(KernelComponentFactory):
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if 128 in result.keys():
|
||||
result[128].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -408,369 +408,29 @@ def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if dtype == "fp32" and tr_load == "f":
|
||||
return [
|
||||
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
64,
|
||||
64,
|
||||
16,
|
||||
64,
|
||||
16,
|
||||
16,
|
||||
64,
|
||||
64,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
64,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
16,
|
||||
16,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
]
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
] # fmt: skip
|
||||
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "f":
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
64,
|
||||
64,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
96,
|
||||
32,
|
||||
96,
|
||||
32,
|
||||
32,
|
||||
96,
|
||||
96,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
128,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
16,
|
||||
32,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
64,
|
||||
256,
|
||||
16,
|
||||
256,
|
||||
16,
|
||||
32,
|
||||
256,
|
||||
256,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
]
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
] # fmt: skip
|
||||
elif (dtype == "fp16" or dtype == "bf16") and tr_load == "t":
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
64,
|
||||
64,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
32,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
192,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
16,
|
||||
32,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
32,
|
||||
16,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
16,
|
||||
64,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
2,
|
||||
32,
|
||||
),
|
||||
FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize(
|
||||
16,
|
||||
16,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
16,
|
||||
16,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
2,
|
||||
16,
|
||||
),
|
||||
]
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
] # fmt: skip
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
@@ -635,578 +635,42 @@ class KernelComponentFactory:
|
||||
if dtype == "fp32":
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
(32, 32): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
64,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(48, 48): [
|
||||
FmhaFwdTileSize(
|
||||
32,
|
||||
128,
|
||||
16,
|
||||
48,
|
||||
16,
|
||||
48,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
16,
|
||||
48,
|
||||
32,
|
||||
48,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
(64, 64): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(96, 128): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
96,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
(192, 192): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
64,
|
||||
32,
|
||||
192,
|
||||
32,
|
||||
192,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(256, 256): [
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
64,
|
||||
32,
|
||||
256,
|
||||
32,
|
||||
256,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
(32, 32): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(64, 64): [
|
||||
FmhaFwdTileSize(
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
32,
|
||||
32,
|
||||
64,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
(96, 128): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
96,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
32,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
32,
|
||||
32,
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
128,
|
||||
16,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
),
|
||||
],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192, 128): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
192,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(192, 192): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
192,
|
||||
32,
|
||||
192,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
1,
|
||||
)
|
||||
],
|
||||
(256, 256): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
256,
|
||||
32,
|
||||
256,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype == "fp8" or dtype == "fp8bf16":
|
||||
return {
|
||||
(64, 64): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
32,
|
||||
64,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
(256, 256): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
256,
|
||||
32,
|
||||
256,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype == "fp8fp32":
|
||||
return {
|
||||
(128, 128): [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -1229,60 +693,9 @@ class KernelComponentFactory:
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
elif dtype in ["fp16", "bf16"]:
|
||||
squant = "f"
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(
|
||||
@@ -1294,137 +707,18 @@ class KernelComponentFactory:
|
||||
["t", "f"],
|
||||
):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
else:
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip
|
||||
if (
|
||||
(hdim, hdim_v) in [(64, 64), (128, 128)]
|
||||
and logits == "f"
|
||||
@@ -1433,103 +727,18 @@ class KernelComponentFactory:
|
||||
and lse == "f"
|
||||
and skip == "f"
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_trload",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"t",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async_trload",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"t",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip
|
||||
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
"f",
|
||||
)
|
||||
) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline( "qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim
|
||||
elif dtype in ["fp8", "fp8bf16", "fp8fp32"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(
|
||||
["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"f",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"f",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
None
|
||||
@@ -1544,33 +753,7 @@ class CustomFactory(KernelComponentFactory):
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -349,43 +349,17 @@ def get_fwd_appendkv_blobs(
|
||||
# applying rotary embedding, so I just use 't' in inter/half pipelines
|
||||
for vlayout in ["row", "col"]:
|
||||
for pagedkv in ["t", "f"]:
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "f", "f", "no", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "no", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
|
||||
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "t", "f", "inter", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "inter", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
|
||||
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "t", "f", "half", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "half", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# rope/paged-kv is not supported
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")
|
||||
)
|
||||
pipelines.append(FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
|
||||
@@ -738,32 +738,18 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
"32": FmhaFwdTileSize(
|
||||
32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1
|
||||
),
|
||||
"64": FmhaFwdTileSize(
|
||||
64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
|
||||
),
|
||||
"96": FmhaFwdTileSize(
|
||||
64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
|
||||
),
|
||||
"128": FmhaFwdTileSize(
|
||||
64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
|
||||
),
|
||||
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize(
|
||||
64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1
|
||||
),
|
||||
}
|
||||
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
elif dtype == "fp8" or dtype == "bf8":
|
||||
return {
|
||||
"64": FmhaFwdTileSize(
|
||||
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
"128": FmhaFwdTileSize(
|
||||
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
}
|
||||
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -807,157 +793,22 @@ def get_fwd_splitkv_blobs(
|
||||
for logits, mask, bias, pagedkv in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
|
||||
):
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"f",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"col",
|
||||
"f",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(Pipeline( "qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline( "qr", "col", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"col",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(Pipeline( "qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline( "qr", "col", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"col",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(Pipeline( "qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline( "qr", "col", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"col",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
pagedkv,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(Pipeline( "qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
pipelines.append(Pipeline( "qr", "col", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(
|
||||
Pipeline(
|
||||
"qr",
|
||||
"col",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"t",
|
||||
squant,
|
||||
"f",
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(Pipeline( "qr", "col", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
|
||||
@@ -524,27 +524,19 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
"128": FmhaFwdTileSize(
|
||||
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1
|
||||
),
|
||||
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# "256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
} # fmt: skip
|
||||
elif dtype == "fp8" or dtype == "bf8":
|
||||
return {
|
||||
"64": FmhaFwdTileSize(
|
||||
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
"128": FmhaFwdTileSize(
|
||||
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
"256": FmhaFwdTileSize(
|
||||
128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
}
|
||||
"64": FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"256": FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -569,82 +561,17 @@ def get_fwd_blobs(
|
||||
["t"],
|
||||
["f"],
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
pagedkv,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
pagedkv,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"t",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"t",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
pass # TODO
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
Reference in New Issue
Block a user