[CK_TILE][FMHA] Add FP8 support for batch_prefill kernel (#3425)

* Add fp8bf16 support for batch_prefill

* Fix wrong scale_s re-compute logic in batch_prefill

* Fix wrong scale_s re-compute logic in fmha fwd

* Fix batch_prefill codegen error

* Remove no-longer used GetName() function

* Add fp8 logits=True instances

* Update CHANGELOG.md
This commit is contained in:
Po Yen Chen
2025-12-24 10:34:06 +08:00
committed by GitHub
parent c0797c1671
commit 1c3151963b
6 changed files with 175 additions and 90 deletions

View File

@@ -24,8 +24,15 @@ from codegen.cpp_symbol_map import (
)
from codegen.utils import update_file
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8": 8,
"fp8bf16": 8,
"fp8fp32": 8,
"bf8": 8,
}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
@@ -108,7 +115,7 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
std::cout << ", {F_kname}" << std::flush;
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
@@ -494,6 +501,7 @@ class FmhaFwdKernel:
@property
def template(self) -> str:
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_kname=self.name,
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
@@ -576,10 +584,14 @@ class FmhaFwdKernel:
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
if dtype in ["fp16", "bf16"]:
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
} # fmt: skip
elif dtype in ["fp8bf16"]:
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
return None
@@ -589,9 +601,9 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later entries in this list will also be checked later
# TODO: currently for qr pipeline, let 't' padding appear later!!
# TODO: how to design this more generic?
qscale = "no"
pipelines = []
if dtype in ["fp16", "bf16"]:
qscale = "no"
for logits, mask, bias, lse, dropout in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
@@ -599,10 +611,16 @@ class KernelComponentFactory:
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
["t", "f"],
["pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
else:
assert False
return pipelines
@@ -612,7 +630,7 @@ class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == "fp16" or dtype == "bf16":
if dtype in ["fp16", "bf16"]:
if 128 in result.keys():
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
return result
@@ -695,15 +713,14 @@ def get_fwd_blobs(
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"

View File

@@ -1017,7 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias, sink in itertools.product(
["f"],
["t", "f"],
["no", "pertensor"],
get_mask_map(mask_impl).keys(),
["no"],