diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 958a6350b2..6ebd0dcb04 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -374,9 +374,16 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) args.hdim_v, args.num_splits, args.scale_o, + args.stride_o_acc, args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, args.nhead_stride_lse, - args.nhead_stride_o); + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc); } else { // create batch mode kernel arguments @@ -391,11 +398,18 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) args.hdim_v, args.num_splits, args.scale_o, + args.stride_o_acc, args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, args.batch_stride_lse, - args.batch_stride_o); + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index c511613fff..fe012ab1ab 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -85,8 +85,18 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t hdim_v; ck_tile::index_t num_splits; + ck_tile::index_t row_stride_o_acc; ck_tile::index_t row_stride_o; + + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; }; struct CommonLSEKargs @@ -132,11 +142,18 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, + ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o) + ck_tile::index_t batch_stride_o, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) { Kargs kargs{{lse_acc_ptr, o_acc_ptr, @@ -147,10 +164,17 @@ struct FmhaFwdSplitKVCombineKernel seqlen_q, hdim_v, num_splits, + row_stride_o_acc, row_stride_o, - nhead_stride_o}, // args for common karg - {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args batch_stride_o}; if constexpr(kStoreLSE) @@ -180,9 +204,16 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t hdim_v, ck_tile::index_t num_splits, float scale_o, + ck_tile::index_t row_stride_o_acc, ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o) + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) { Kargs kargs{{lse_acc_ptr, o_acc_ptr, @@ -193,10 +224,17 @@ struct FmhaFwdSplitKVCombineKernel -1, // seqlen will be updated by another pointer hdim_v, num_splits, + row_stride_o_acc, row_stride_o, - nhead_stride_o}, // args for common karg - {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args reinterpret_cast(seqstart_q_ptr)}; if constexpr(kStoreLSE) @@ -239,20 +277,18 @@ struct FmhaFwdSplitKVCombineKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_lse_acc = 0; - long_index_t batch_offset_o_acc = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_lse_acc = - static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); - batch_offset_o_acc = static_cast(i_batch) * - (kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v); if constexpr(kStoreLSE) { batch_offset_lse = @@ -273,10 +309,6 @@ struct FmhaFwdSplitKVCombineKernel } else { - batch_offset_lse_acc = - static_cast(i_batch) * (kargs.nhead * kargs.max_seqlen_q); - batch_offset_o_acc = static_cast(i_batch) * - (kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v); if constexpr(kStoreLSE) { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; @@ -285,13 +317,12 @@ struct FmhaFwdSplitKVCombineKernel } // for simplicity, batch stride we just modify the pointer - const LSEDataType* lse_acc_ptr = reinterpret_cast(kargs.lse_acc_ptr) + - static_cast(i_nhead) * (kargs.max_seqlen_q) + - batch_offset_lse_acc; + const LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc; const OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + - static_cast(i_nhead) * (kargs.max_seqlen_q * kargs.hdim_v) + - batch_offset_o_acc; + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; @@ -301,7 +332,7 @@ struct FmhaFwdSplitKVCombineKernel const auto lse_acc_dram_naive = make_naive_tensor_view( lse_acc_ptr, make_tuple(kargs.num_splits, kargs.seqlen_q), - make_tuple(kargs.batch * kargs.nhead * kargs.max_seqlen_q, 1), + make_tuple(kargs.split_stride_lse_acc, 1), number<8>{}, number<1>{}); @@ -315,8 +346,7 @@ struct FmhaFwdSplitKVCombineKernel const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), - make_tuple( - kargs.batch * kargs.nhead * kargs.max_seqlen_q * kargs.hdim_v, kargs.hdim_v, 1), + make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), number{}, number<1>{}); @@ -390,8 +420,8 @@ struct FmhaFwdSplitKVCombineKernel lse_dram_window, identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - smem_ptr, kargs.num_splits, + smem_ptr, kargs.seqlen_q, kargs.max_seqlen_q); } @@ -400,8 +430,8 @@ struct FmhaFwdSplitKVCombineKernel return FmhaPipeline{}(lse_acc_dram_window, o_acc_dram_window, lse_dram_window, - smem_ptr, kargs.num_splits, + smem_ptr, kargs.seqlen_q, kargs.max_seqlen_q); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index e1b8ae12cb..340df1a094 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -82,8 +82,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline LSEDramBlockWindowTmp& lse_dram_window_tmp, const LSEElementFunction& lse_element_func, const OaccElementFunction& o_acc_element_func, - void* smem_ptr, index_t num_splits, + void* smem_ptr, index_t real_seqlen_q, index_t max_seqlen_q) const { @@ -311,8 +311,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window, const OaccDramBlockWindow& o_acc_dram_block_window, LSEDramBlockWindow& lse_dram_block_window, - void* smem_ptr, index_t num_splits, + void* smem_ptr, index_t real_seqlen_q, index_t max_seqlen_q) const { @@ -321,8 +321,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_dram_block_window, identity{}, identity{}, - smem_ptr, num_splits, + smem_ptr, real_seqlen_q, max_seqlen_q); }