From 7789b53e153f301e9156eeb2ef88045390a5efe2 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 7 Aug 2024 04:51:21 +0000 Subject: [PATCH] Add tile navigators to the appendkv kernel --- .../fmha/kernel/fmha_fwd_appendkv_kernel.hpp | 100 +++++++++++++++++- .../block_fmha_fwd_appendkv_pipeline.hpp | 36 ++++++- 2 files changed, 128 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index b9de3f044e..ee272cd494 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -139,6 +139,9 @@ struct FmhaFwdAppendKVKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; }; using Kargs = std::conditional_t; @@ -361,6 +364,58 @@ struct FmhaFwdAppendKVKernel batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; } + auto k_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k; + + return PagedTileWindowNavigator(kargs.k_ptr, + kargs.batch_stride_k, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size); + } + else + { + return SimpleTileWindowNavigator(); + } + }(); + + auto v_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + const long_index_t fixed_offset = + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v; + + return PagedTileWindowNavigator(kargs.v_ptr, + kargs.batch_stride_v, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size); + } + else + { + return SimpleTileWindowNavigator(); + } + }(); + // for simplicity, batch stride we just modify the pointer QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + @@ -397,9 +452,20 @@ struct FmhaFwdAppendKVKernel sequence{}); }(); const auto k_dram = [&]() { + const auto lengths = [&]() { + if constexpr(kIsPagedKV) + { + return make_tuple(kargs.page_block_size, kargs.hdim_q); + } + else + { + return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q); + } + }(); + const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q), + lengths, make_tuple(kargs.stride_k, 1), number{}, number<1>{}); @@ -425,9 +491,20 @@ struct FmhaFwdAppendKVKernel const auto v_dram = [&]() { if constexpr(std::is_same_v) { + const auto lengths = [&]() { + if constexpr(kIsPagedKV) + { + return make_tuple(kargs.page_block_size, kargs.hdim_v); + } + else + { + return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v); + } + }(); + const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v), + lengths, make_tuple(kargs.stride_v, 1), number{}, number<1>{}); @@ -446,9 +523,20 @@ struct FmhaFwdAppendKVKernel } else { + const auto lengths = [&]() { + if constexpr(kIsPagedKV) + { + return make_tuple(kargs.hdim_v, kargs.page_block_size); + } + else + { + return make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew); + } + }(); + const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew), + lengths, make_tuple(kargs.stride_v, 1), number{}, number<1>{}); @@ -644,6 +732,8 @@ struct FmhaFwdAppendKVKernel knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, kargs.rotary_dim, + k_tile_navigator, + v_tile_navigator, kargs.seqlen_q <= i_m0, kargs.seqlen_knew <= i_n0); } @@ -658,7 +748,9 @@ struct FmhaFwdAppendKVKernel q_rotary_sin_dram_window, knew_rotary_cos_dram_window, knew_rotary_sin_dram_window, - 0, + 0, // rotary_dim not used + k_tile_navigator, + v_tile_navigator, kargs.seqlen_q <= i_m0, kargs.seqlen_knew <= i_n0); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index d6d071f65c..53467c5e93 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -83,7 +83,9 @@ struct BlockFmhaFwdAppendKVPipeline typename QRotaryCosDramBlockWindow, typename QRotarySinDramBlockWindow, typename KnewRotaryCosDramBlockWindow, - typename KnewRotarySinDramBlockWindow> + typename KnewRotarySinDramBlockWindow, + typename KTileWindowNavigator, + typename VTileWindowNavigator> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, @@ -98,6 +100,8 @@ struct BlockFmhaFwdAppendKVPipeline const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, index_t rotary_dim, + const KTileWindowNavigator& k_tile_navigator, + const VTileWindowNavigator& v_tile_navigator, bool skip_transform_q, bool skip_append_kv) const { @@ -141,7 +145,16 @@ struct BlockFmhaFwdAppendKVPipeline rotary_dim, thread_end); } - store_tile(k_dram_block_window, knew_tile); + + if constexpr(kIsPagedKV) + { + /// TODO: handle cross-page-block write + store_tile(k_dram_block_window, knew_tile); + } + else + { + store_tile(k_dram_block_window, knew_tile); + } // append Vnew to V auto vnew_window = make_tile_window( @@ -151,7 +164,16 @@ struct BlockFmhaFwdAppendKVPipeline auto vnew = load_tile(vnew_window); return tile_elementwise_in(vnew_element_func, vnew); }(); - store_tile(v_dram_block_window, vnew_tile); + + if constexpr(kIsPagedKV) + { + /// TODO: handle cross-page-block write + store_tile(v_dram_block_window, vnew_tile); + } + else + { + store_tile(v_dram_block_window, vnew_tile); + } } if(!skip_transform_q) @@ -201,7 +223,9 @@ struct BlockFmhaFwdAppendKVPipeline typename QRotaryCosDramBlockWindow, typename QRotarySinDramBlockWindow, typename KnewRotaryCosDramBlockWindow, - typename KnewRotarySinDramBlockWindow> + typename KnewRotarySinDramBlockWindow, + typename KTileWindowNavigator, + typename VTileWindowNavigator> CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow& q_dram_block_window, KDramBlockWindow& k_dram_block_window, @@ -213,6 +237,8 @@ struct BlockFmhaFwdAppendKVPipeline const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window, index_t rotary_dim, + const KTileWindowNavigator& k_tile_navigator, + const VTileWindowNavigator& v_tile_navigator, bool skip_transform_q, bool skip_append_kv) const { @@ -229,6 +255,8 @@ struct BlockFmhaFwdAppendKVPipeline knew_rotary_cos_dram_block_window, knew_rotary_sin_dram_block_window, rotary_dim, + k_tile_navigator, + v_tile_navigator, skip_transform_q, skip_append_kv); }