Refine pipeline padding settings

This commit is contained in:
PoYen, Chen
2024-07-24 11:37:56 +00:00
parent f053ae2b5b
commit 4280a07d2a

View File

@@ -115,8 +115,8 @@ class FmhaFwdAppendKVApiTrait:
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bsk} != 0*/'
else : return f'a.seqlen_k % {self.bsk} == 0'
if self.skpad == 't' : return f'true /*a.seqlen_knew % {self.bsk} != 0*/'
else : return f'a.seqlen_knew % {self.bsk} == 0'
@property
def dcheck(self) -> str:
@@ -286,11 +286,17 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for rope in ROPE_MAP.keys():
# FIXME: it will be very complicated if we consider all the padding cases,
# so I just use 't' for the padding flags
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope))
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope))
# NOTE: handling every hdim_q padding case while applying rotary embedding would be very
# complicated, so the inter/half pipelines always use 't' for that padding flag
for vlayout in ['row', 'col']:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half'))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half'))
elif dtype in ['fp8', 'bf8']:
# rope is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no'))