Fix skcheck logic

This commit is contained in:
PoYen, Chen
2024-08-16 10:08:01 +00:00
parent 5805f5aa73
commit a4c6029a3d

View File

@@ -114,9 +114,8 @@ class FmhaFwdAppendKVApiTrait:
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.skpad == 't' : return f'true /*a.seqlen_knew % {self.bsk} != 0*/'
else : return f'a.seqlen_knew % {self.bsk} == 0'
# we do not check all the values in a.seqlen_k_ptr
return 't'
@property
def dcheck(self) -> str:
@@ -291,13 +290,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 'f', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported