From a4c6029a3dc1ca99aaebce60f9ebe1bf1dbe10f2 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 16 Aug 2024 10:08:01 +0000 Subject: [PATCH] Fix skcheck logic --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index e975a10c04..a5d02c92cf 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -114,9 +114,8 @@ class FmhaFwdAppendKVApiTrait: @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.skpad == 't' : return f'true /*a.seqlen_knew % {self.bsk} != 0*/' - else : return f'a.seqlen_knew % {self.bsk} == 0' + # we do not check all the values in a.seqlen_k_ptr + return 't' @property def dcheck(self) -> str: @@ -291,13 +290,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> # applying rotary embedding, so I just use 't' in inter/half pipelines for vlayout in ['row', 'col']: for pagedkv in ["t", "f"]: - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) elif dtype in ['fp8', 'bf8']: # rope/paged-kv is not supported