Add input fp8 and output bf16 attention (#2726)

* change host using fp16 to check

* fp8 to fp8 compare

* rewrite input parameters

* add not squant

* remove some output code

* for scale = 1

* format

* saturates only for fp8

* add fp8bf16 data type

* add fp8bf16 data type

* fix test fp8 code

* add run_fp8bf16_tests

* change fmha fwd example parameter(adding fp8bf16)

* Support fp8bf16 for Aiter

* Support aiter fp8bf16 in c++

* fix comment about fp8 in readme.md

* add fp8fp32

* add fp8fp32 test

* remove range_q etc.

* format

* fix test parameters about squant and fmha example input fp8bf16 fp8fp32 data type

* add fp8bf16 to data_type function

* change colmajor to rowmajor in test_ck_tile_fmha_fwd_fp8

* format

* reset atol for fp8

* fix bug for atol

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
ltqin
2025-09-19 14:26:43 +08:00
committed by GitHub
parent e469fee046
commit dd249f1cd6
12 changed files with 262 additions and 153 deletions

View File

@@ -7,7 +7,8 @@ FWD_DTYPE_MAP = {
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32"
}
BWD_DTYPE_MAP = {

View File

@@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
@@ -248,11 +248,11 @@ class FmhaFwdApiTrait:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def seqtune(self) -> str:
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
else:
else:
return f'a.seqlen_q <= {self.bm0}'
@property
@@ -351,7 +351,7 @@ class FmhaFwdPipeline:
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_trload == 't' : n += '_trload'
else: n += '_ntrload'
@@ -378,7 +378,7 @@ class FmhaFwdApiPool:
"t": "has_load_tr",
"f": "true"
}
per_tr_load =str()
for tr_load in ["t", "f"]:
per_dtypes=str()
@@ -550,12 +550,16 @@ class KernelComponentFactory:
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
elif dtype == 'fp8' or dtype == 'bf8':
elif dtype == 'fp8' or dtype == 'fp8bf16':
return {
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
elif dtype == 'fp8fp32':
return {
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
else:
return None
@@ -567,9 +571,9 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256 and hdim_v == 256:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
@@ -589,11 +593,12 @@ class KernelComponentFactory:
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'bf8']:
# TODO
None
else:
@@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if dtype == 'fp8bf16':
cond &= hdim == 128
if not cond:
continue
elif receipt == 888:
cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
cond &= pipeline.F_vlayout == 'row'
cond &= hdim == 128
if not cond:
continue

View File

@@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None

View File

@@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
elif dtype in ['fp8', 'bf8']:
# TODO
None
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None