From 4cc514f803acb4f0eb251fa60e22d75faf8e625b Mon Sep 17 00:00:00 2001 From: danyao12 Date: Wed, 7 Aug 2024 11:00:33 +0000 Subject: [PATCH] fix unpadded lse issue in fwd splitkv --- .../fmha_fwd_splitkv_combine_kernel.hpp | 29 +++++++++---------- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 22 +++++--------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index b90e04f63f..e2c7db3e1b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -99,10 +99,9 @@ struct FmhaFwdSplitKVCombineKernel struct CommonLSEKargs { - void* lse_ptr = nullptr; - ck_tile::index_t nhead_stride_lse = 0; - ck_tile::index_t batch_stride_lse_acc = 0; - ck_tile::index_t batch_stride_lse = 0; + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; }; struct Fp8StaticQuantKargs @@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel std::conditional_t> { ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_lse_acc; }; struct GroupModeKargs @@ -171,14 +171,14 @@ struct FmhaFwdSplitKVCombineKernel split_stride_o_acc}, // args for common karg {}, // placeholder for lse {}, // placeholder for fp8_static_quant args - batch_stride_o}; + batch_stride_o, + batch_stride_lse_acc}; if constexpr(kStoreLSE) { - kargs.lse_ptr = lse_ptr; - kargs.nhead_stride_lse = nhead_stride_lse; - kargs.batch_stride_lse_acc = batch_stride_lse_acc; - kargs.batch_stride_lse = batch_stride_lse; + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { @@ -282,12 +282,12 @@ struct FmhaFwdSplitKVCombineKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_o = query_start * kargs.row_stride_o; + batch_offset_o = query_start * kargs.row_stride_o; + batch_offset_lse_acc = query_start; if constexpr(kStoreLSE) { - batch_offset_lse_acc = query_start; - batch_offset_lse = query_start; + batch_offset_lse = query_start; } // get real # queries & # keys under group mode @@ -303,12 +303,11 @@ struct FmhaFwdSplitKVCombineKernel } else { - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; if constexpr(kStoreLSE) { - batch_offset_lse_acc = - static_cast(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 4cbc93aaad..36c10db79c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -47,7 +47,6 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; - static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -520,8 +519,9 @@ struct FmhaFwdSplitKVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_lse_acc = query_start; if constexpr(std::is_same_v) { batch_offset_v = key_start * kargs.stride_v; @@ -538,10 +538,6 @@ struct FmhaFwdSplitKVKernel { batch_offset_randval = query_start * kargs.stride_randval; } - if constexpr(kStoreLSE) - { - batch_offset_lse_acc = query_start; - } // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -566,9 +562,10 @@ struct FmhaFwdSplitKVKernel } else { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -578,11 +575,6 @@ struct FmhaFwdSplitKVKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } - if constexpr(kStoreLSE) - { - batch_offset_lse_acc = - static_cast(i_batch) * kargs.batch_stride_lse_acc; - } } // for simplicity, batch stride we just modify the pointer