From bd0d2f397598b140d15175fe3a7312e94c0e8bd3 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 6 Aug 2024 08:02:43 +0000 Subject: [PATCH] Add batch_stride_k/batch_stride_v in group mode --- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 27bd9305a5..54f7992c97 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -237,6 +237,9 @@ struct FmhaFwdSplitKVKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; }; using Kargs = std::conditional_t; @@ -408,6 +411,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, @@ -460,7 +465,9 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_k_ptr), + batch_stride_k, + batch_stride_v}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -582,16 +589,8 @@ struct FmhaFwdSplitKVKernel else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - if constexpr(kIsPagedKV) - { - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - } - else - { - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - } + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) {