[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)

* Generate group mode paged-attn kernel

* Enable paged-kvcache + group mode support

* Add missing header: fused_moe.hpp

* Add comment to explain kernel arg usage

* Make error message clearer

* Add comment for confusing data member names

* Add more comments for confusing variable names

* Fix typo in option description
This commit is contained in:
Po Yen Chen
2024-11-21 14:53:10 +08:00
committed by GitHub
parent 6916d8cc03
commit fb1ccfa9df
6 changed files with 95 additions and 43 deletions

View File

@@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding
continue
if pipeline.F_pagedkv == 't':
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k = Kernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,