fp8 fmha async pipeline (#3339)

* Replace the qr pipeline with the qr_async pipeline

* Add fp8fp32 to DTYPE_BITS

* Add kAlignmentRandVal to avoid a compile failure

* format

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
rocking
2025-12-04 12:18:25 +08:00
committed by GitHub
parent 4baa4c9fae
commit eb7f617713
2 changed files with 18 additions and 7 deletions

View File

@@ -29,7 +29,15 @@ from codegen.cpp_symbol_map import (
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8": 8,
"fp8bf16": 8,
"fp8fp32": 8,
"bf8": 8,
}
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
@@ -678,6 +686,7 @@ class KernelComponentFactoryGfx9:
return {
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
elif dtype in ["fp8fp32"]:
@@ -742,8 +751,8 @@ class KernelComponentFactoryGfx9:
get_mask_map(mask_impl).keys(),
["no"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
None
@@ -958,7 +967,7 @@ def get_fwd_blobs(
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
cond &= hdim == 128 or hdim == 256
cond &= hdim == 128 or hdim == 192
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
@@ -967,7 +976,7 @@ def get_fwd_blobs(
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
cond &= hdim == 128 or hdim == 256
cond &= hdim == 128 or hdim == 192
if not cond:
continue
# aiter::mha_fwd C++ api integration
@@ -975,13 +984,13 @@ def get_fwd_blobs(
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= pipeline.F_vlayout == "row"
if dtype == "fp8bf16":
cond &= hdim == 128 or hdim == 256
cond &= hdim == 128 or hdim == 192
if not cond:
continue
elif receipt == 888:
cond = dtype in ["fp8bf16", "fp8fp32"]
cond &= pipeline.F_vlayout == "row"
cond &= hdim == 128 or hdim == 256
cond &= hdim == 128 or hdim == 192
if not cond:
continue