From c7b7b4488375c30f89eae421dfbd1ae3a0cebf67 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Jul 2024 05:13:16 +0000 Subject: [PATCH] Add comment for why I just 't' for all padding flags --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index ab572debbd..63442eef0b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -287,9 +287,8 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> pipelines = [] if dtype in ['fp16', 'bf16']: for rope in ROPE_MAP.keys(): - # pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope)) - # pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope)) - + # FIXME: it will be very complicated if we consider all the padding cases, + # so I just use 't' on all the dimensions pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope)) pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope)) elif dtype in ['fp8', 'bf8']: