generate async pipelin code for fp8

This commit is contained in:
ltqin
2025-09-23 06:48:42 +00:00
parent 4363a82bd6
commit 43fa6ccaf7

View File

@@ -20,7 +20,8 @@ DTYPE_BITS = {
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
"bf8" : 8,
"fp8bf16" : 8
}
K0_MAX_SUBMAX_MAP = {
@@ -552,9 +553,9 @@ class KernelComponentFactory:
}
elif dtype == 'fp8' or dtype == 'fp8bf16':
return {
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
# (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
# (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
elif dtype == 'fp8fp32':
return {
@@ -593,12 +594,12 @@ class KernelComponentFactory:
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
elif dtype in ['fp8', 'fp8bf16']:
# no need lse/dropout kernels
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'bf8']:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
elif dtype in ['fp8fp16', 'bf8', 'fp8fp32']:
# TODO
None
else: