mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fp8 fmha async pipeline (#3339)
* replace qr with async pipeline * Add fp8fp32 to DTYPE_BITS * Add kAlignmentRandVal to avoid compile fail * format --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -29,7 +29,15 @@ from codegen.cpp_symbol_map import (
|
||||
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8": 8,
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
|
||||
@@ -678,6 +686,7 @@ class KernelComponentFactoryGfx9:
|
||||
return {
|
||||
( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8fp32"]:
|
||||
@@ -742,8 +751,8 @@ class KernelComponentFactoryGfx9:
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
None
|
||||
@@ -958,7 +967,7 @@ def get_fwd_blobs(
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
cond &= hdim == 128 or hdim == 192
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
@@ -967,7 +976,7 @@ def get_fwd_blobs(
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
cond &= hdim == 128 or hdim == 192
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
@@ -975,13 +984,13 @@ def get_fwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if dtype == "fp8bf16":
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
cond &= hdim == 128 or hdim == 192
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ["fp8bf16", "fp8fp32"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= hdim == 128 or hdim == 256
|
||||
cond &= hdim == 128 or hdim == 192
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user