diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index da1165f9db..d5d5db4d28 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -115,8 +115,8 @@ class FmhaFwdAppendKVApiTrait: @property def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bsk} != 0*/' - else : return f'a.seqlen_k % {self.bsk} == 0' + if self.skpad == 't' : return f'true /*a.seqlen_knew % {self.bsk} != 0*/' + else : return f'a.seqlen_knew % {self.bsk} == 0' @property def dcheck(self) -> str: @@ -286,11 +286,17 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for rope in ROPE_MAP.keys(): - # FIXME: it will be very complicated if we consider all the padding cases, - # so I just use 't' for the padding flags - pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope)) - pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope)) + # NOTE: handling every hdim_q padding case while applying rotary embedding would be + # very complicated, so the inter/half pipelines always use 't' for the hdim_q padding flag + for vlayout in ['row', 'col']: + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no')) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no')) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter')) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter')) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half')) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half')) elif dtype in ['fp8', 'bf8']: # rope is not 
supported pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no'))