diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py index 75e9996b7a..0dd3862b1a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py @@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return None def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'bf16': + if dtype == 'fp16' or dtype == 'bf16': return { ### '32' : FmhaFwdSplitKVCombineTileSize(32, -1), ### '64' : FmhaFwdSplitKVCombineTileSize(32, -1), @@ -736,7 +736,7 @@ def get_batch_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # Aiter(batch_decode) integration elif receipt == 200: - cond = dtype in [ 'bf16'] + cond = dtype in [ 'fp16','bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias == 'no'