diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3b0f4aaf4b..ab6d320ab4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -541,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' + cond &= ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f') cond &= pipeline.F_squant == 'f' if not cond: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 26753f5ea4..f362ebd7a6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -738,6 +738,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" cond &= pipeline.F_vlayout == 'row' + cond &= ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f') cond &= pipeline.F_squant == 'f' if not cond: continue