mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add input fp8 and output bf16 attention (#2726)
* change host using fp16 to check * fp8 to fp8 compare * rewrite input parameters * add not squant * remove some output code * for scale = 1 * format * saturates only for fp8 * add fp8bf16 data type * add fp8bf16 data type * fix test fp8 code * add run_fp8bf16_tests * change fmha fwd example parameter(adding fp8bf16) * Support fp8bf16 for Aiter * Support aiter fp8bf16 in c++ * fix comment about fp8 in readme.md * add fp8fp32 * add fp8fp32 test * remove range_q etc. * format * fix test parameters about squant and fmha example input fp8bf16 fp8fp32 data type * add fp8bf16 to data_type function * change colmajor to rowmajor in test_ck_tile_fmha_fwd_fp8 * format * reset atol for fp8 * fix bug for atol --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -7,7 +7,8 @@ FWD_DTYPE_MAP = {
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
|
||||
@@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
@@ -248,11 +248,11 @@ class FmhaFwdApiTrait:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
|
||||
else:
|
||||
else:
|
||||
return f'a.seqlen_q <= {self.bm0}'
|
||||
|
||||
@property
|
||||
@@ -351,7 +351,7 @@ class FmhaFwdPipeline:
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
|
||||
@@ -378,7 +378,7 @@ class FmhaFwdApiPool:
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
}
|
||||
|
||||
|
||||
per_tr_load =str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes=str()
|
||||
@@ -550,12 +550,16 @@ class KernelComponentFactory:
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
elif dtype == 'fp8' or dtype == 'fp8bf16':
|
||||
return {
|
||||
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
elif dtype == 'fp8fp32':
|
||||
return {
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -567,9 +571,9 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
squant = 'f'
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
@@ -589,11 +593,12 @@ class KernelComponentFactory:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
|
||||
@@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
|
||||
Reference in New Issue
Block a user