diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 04c85892ac..08a8d5bcf4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -219,7 +219,6 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_lse_acc; }; using Kargs = std::conditional_t; @@ -297,7 +296,8 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc batch_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias @@ -377,7 +377,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse_acc ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left,