From ecaaa6f136fe5b327133a79fec04074bc60776e5 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 5 Aug 2024 16:31:31 +0000 Subject: [PATCH] Simplify TileWindowNavigator interfaces --- .../block/block_tile_window_navigator.hpp | 106 +++++++++++------- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 24 ++-- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp index ed39baa0e1..1c1ef27b48 100644 --- a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp @@ -8,18 +8,17 @@ namespace ck_tile { -template +template struct SimpleTileWindowNavigator { - using DataType = DataType_; + using DataType = DataType_; + using TensorViewLengths = TensorViewLengths_; + using TensorViewStrides = TensorViewStrides_; - template - CK_TILE_DEVICE static constexpr auto - make_tile_window(const TensorView& tensor_view, - const WindowLengths& window_lengths, - const multi_index& window_origin) + CK_TILE_DEVICE constexpr SimpleTileWindowNavigator(const TensorViewLengths& lengths_, + const TensorViewStrides& strides_) + : lengths(lengths_), strides(strides_) { - return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin); } template @@ -30,14 +29,6 @@ struct SimpleTileWindowNavigator return ck_tile::make_tile_window(tile_window, window_origin); } - template - CK_TILE_DEVICE constexpr auto - make_tile_window(const tile_window_with_static_lengths& tile_window, - const StaticTileDistribution& tile_distribution) - { - return ck_tile::make_tile_window(tile_window, tile_distribution); - } - template CK_TILE_DEVICE void move_tile_window(TileWindow& tile_window, @@ -45,14 +36,30 @@ struct SimpleTileWindowNavigator { ck_tile::move_tile_window(tile_window, step); } + + TensorViewLengths lengths; + TensorViewStrides strides; }; -template +template +CK_TILE_DEVICE constexpr auto make_tile_window_navigator(const TensorViewLengths& lengths, + const TensorViewStrides& strides) +{ + return SimpleTileWindowNavigator(lengths, + strides); +} + +template struct PagedTileWindowNavigator { using DataType = DataType_; static constexpr index_t VirtualDim = VirtualDim_; static_assert(VirtualDim == 0 || VirtualDim == 1); + using TensorViewLengths = TensorViewLengths_; + using TensorViewStrides = TensorViewStrides_; CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t* blocks_, long_index_t block_stride_, @@ -60,46 +67,30 @@ struct PagedTileWindowNavigator long_index_t row_stride_, const int32_t* block_indices_, index_t num_blocks_, - index_t page_block_size_) + index_t page_block_size_, + const TensorViewLengths& lengths_, + const TensorViewStrides& strides_) : blocks(reinterpret_cast(blocks_)), block_stride(block_stride_), head_stride(head_stride_), row_stride(row_stride_), block_indices(block_indices_), num_blocks(num_blocks_), - page_block_size(page_block_size_) + page_block_size(page_block_size_), + lengths(lengths_), + strides(strides_) { } - template - CK_TILE_DEVICE auto - make_tile_window(const TensorView& tensor_view, - const WindowLengths& window_lengths, - const multi_index& window_origin) const - { - auto tile_window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin); - /// TODO: convert global window origin to local window origin - return tile_window; - } - - template - CK_TILE_DEVICE auto - make_tile_window(const tile_window_with_static_lengths& tile_window, - const StaticTileDistribution& tile_distribution) const - { - auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution); - /// TODO: convert global window origin to local window origin - return new_tile_window; - } - template CK_TILE_DEVICE auto make_tile_window(const tile_window_with_static_lengths& tile_window, const multi_index& window_origin) const { - auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin); /// TODO: convert global window origin to local window origin - return new_tile_window; + auto local_window_origin = window_origin; + + return ck_tile::make_tile_window(tile_window, local_window_origin); } template @@ -107,10 +98,10 @@ struct PagedTileWindowNavigator move_tile_window(TileWindow& tile_window, const typename remove_cvref_t::BottomTensorIndex& step) const { + /// TODO: reset pointer and adjust local window origin ck_tile::move_tile_window(tile_window, step); } - private: DataType* get_block_base(index_t block_index) { return blocks + block_index * block_stride + head_stride; @@ -126,6 +117,35 @@ struct PagedTileWindowNavigator const int32_t* block_indices; index_t num_blocks; index_t page_block_size; + + TensorViewLengths lengths; + TensorViewStrides strides; }; +template +CK_TILE_DEVICE constexpr auto make_tile_window_navigator(copy_const_t* blocks, + long_index_t block_stride, + long_index_t head_stride, + long_index_t row_stride, + const int32_t* block_indices, + index_t num_blocks, + index_t page_block_size, + const TensorViewLengths& lengths, + const TensorViewStrides& strides) +{ + return PagedTileWindowNavigator( + blocks, + block_stride, + head_stride, + row_stride, + block_indices, + num_blocks, + page_block_size, + lengths, + strides); +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index ca54a496d6..6072f7d862 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -542,17 +542,21 @@ struct FmhaFwdSplitKVKernel const index_t num_blocks = integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); - return PagedTileWindowNavigator(kargs.k_ptr, - kargs.batch_stride_k, - kargs.nhead_stride_k, - kargs.stride_k, - block_indices, - num_blocks, - kargs.page_block_size); + return make_tile_window_navigator( + kargs.k_ptr, + kargs.batch_stride_k, + kargs.nhead_stride_k, + kargs.stride_k, + block_indices, + num_blocks, + kargs.page_block_size, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1)); } else { - return SimpleTileWindowNavigator(); + return make_tile_window_navigator( + make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1)); } }(); @@ -688,8 +692,8 @@ struct FmhaFwdSplitKVKernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), + k_tile_navigator.lengths, + k_tile_navigator.strides, number{}, number<1>{});