From 751d114b91f70b2f261560db2faa3282cae893d3 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Mon, 16 Jun 2025 00:52:12 +0000 Subject: [PATCH] test fp16 ok --- example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py index 75e9996b7a..0dd3862b1a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py @@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return None def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'bf16': + if dtype == 'fp16' or dtype == 'bf16': return { ### '32' : FmhaFwdSplitKVCombineTileSize(32, -1), ### '64' : FmhaFwdSplitKVCombineTileSize(32, -1), @@ -736,7 +736,7 @@ def get_batch_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # Aiter(batch_decode) integration elif receipt == 200: - cond = dtype in [ 'bf16'] + cond = dtype in [ 'fp16','bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias == 'no'