From b98985262d8c06e7d3bb914054eba7972cd5e57d Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 6 Aug 2024 14:54:07 +0000 Subject: [PATCH] Add missing kernel arguments for group mode --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 7 ++++--- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 613f3b6e66..c531da1ccf 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -108,8 +108,8 @@ template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ if constexpr({F_mode} == false) {{ // batch mode - if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0 - && a.seqlen_k % (a.num_splits * {F_bk1}) == 0) {{ + // make sure F_bn0 is divisible by F_bk1 + if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ kernel_runner::run(s, a); }} else {{ kernel_runner::run(s, a); @@ -527,7 +527,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: # splitkv kernel donot support dropout for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]): - if hdim == 256 or hdim == 32 or hdim == 64 or hdim == 128: + # TODO: use async pipeline when compiler is more stable + if hdim == 256 or hdim in [32, 64, 128]: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 546b70a4b0..c4535aac5e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -336,6 +336,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, + args.batch_stride_k, + args.batch_stride_v, args.batch_stride_lse_acc, args.batch_stride_o_acc, args.split_stride_lse_acc,