From 43fa6ccaf730e0a1c26497b4528e0bacef234ea2 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 23 Sep 2025 06:48:42 +0000 Subject: [PATCH] generate async pipelin code for fp8 --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cfb96b7d53..ed7c8f5464 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -20,7 +20,8 @@ DTYPE_BITS = { "fp16": 16, "bf16": 16, "fp8" : 8, - "bf8" : 8 + "bf8" : 8, + "fp8bf16" : 8 } K0_MAX_SUBMAX_MAP = { @@ -552,9 +553,9 @@ class KernelComponentFactory: } elif dtype == 'fp8' or dtype == 'fp8bf16': return { - (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + # (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } elif dtype == 'fp8fp32': return { @@ -593,12 +594,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: + elif dtype in ['fp8', 'fp8bf16']: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'bf8']: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8', 'fp8fp32']: # TODO None else: