Add comment for why I just 't' for all padding flags

This commit is contained in:
PoYen, Chen
2024-07-24 05:13:16 +00:00
parent 59e1d9b84f
commit c7b7b44883

View File

@@ -287,9 +287,8 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines = []
if dtype in ['fp16', 'bf16']:
for rope in ROPE_MAP.keys():
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope))
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope))
# FIXME: it will be very complicated if we consider all the padding cases,
# so I just use 't' on all the dimensions
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope))
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope))
elif dtype in ['fp8', 'bf8']: