From bb7835326448d741631c28baba42ffc88a4c5433 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 5 Aug 2024 21:52:59 +0000 Subject: [PATCH] Remove unnecessary data members --- .../block/block_tile_window_navigator.hpp | 70 ++----------------- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 38 ++++++---- 2 files changed, 31 insertions(+), 77 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp index 2357ee8586..49ff4db4ee 100644 --- a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp @@ -8,19 +8,11 @@ namespace ck_tile { -template +template struct SimpleTileWindowNavigator { - using DataType = DataType_; - using WindowOrigin = multi_index<2>; - using TensorViewLengths = TensorViewLengths_; - using TensorViewStrides = TensorViewStrides_; - - CK_TILE_DEVICE constexpr SimpleTileWindowNavigator(const TensorViewLengths& lengths_, - const TensorViewStrides& strides_) - : lengths(lengths_), strides(strides_) - { - } + using DataType = DataType_; + using WindowOrigin = multi_index<2>; template CK_TILE_DEVICE static constexpr auto @@ -60,31 +52,15 @@ struct SimpleTileWindowNavigator { return local_window_origin; } - - TensorViewLengths lengths; - TensorViewStrides strides; }; -template -CK_TILE_DEVICE constexpr auto make_tile_window_navigator(const TensorViewLengths& lengths, - const TensorViewStrides& strides) -{ - return SimpleTileWindowNavigator(lengths, - strides); -} - -template +template struct PagedTileWindowNavigator { using DataType = DataType_; static constexpr index_t VirtualDim = VirtualDim_; static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); - using WindowOrigin = multi_index<2>; - using TensorViewLengths = TensorViewLengths_; - using TensorViewStrides = TensorViewStrides_; + using WindowOrigin = multi_index<2>; CK_TILE_DEVICE constexpr 
PagedTileWindowNavigator( copy_const_t* physical_blocks_, @@ -92,17 +68,13 @@ struct PagedTileWindowNavigator long_index_t fixed_offset_, const int32_t* physical_block_indices_, index_t num_blocks_, - index_t page_block_size_, - const TensorViewLengths& lengths_, - const TensorViewStrides& strides_) + index_t page_block_size_) : physical_blocks(reinterpret_cast(physical_blocks_)), block_stride(block_stride_), fixed_offset(fixed_offset_), physical_block_indices(physical_block_indices_), num_blocks(num_blocks_), - page_block_size(page_block_size_), - lengths(lengths_), - strides(strides_) + page_block_size(page_block_size_) { } @@ -194,34 +166,6 @@ struct PagedTileWindowNavigator const int32_t* physical_block_indices; index_t num_blocks; index_t page_block_size; - - TensorViewLengths lengths; - TensorViewStrides strides; }; -template -CK_TILE_DEVICE constexpr auto -make_tile_window_navigator(copy_const_t* physical_blocks, - long_index_t block_stride, - long_index_t fixed_offset, - const int32_t* physical_block_indices, - index_t num_blocks, - index_t page_block_size, - const TensorViewLengths& lengths, - const TensorViewStrides& strides) -{ - return PagedTileWindowNavigator( - physical_blocks, - block_stride, - fixed_offset, - physical_block_indices, - num_blocks, - page_block_size, - lengths, - strides); -} - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index e43b736778..379435860a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -542,21 +542,20 @@ struct FmhaFwdSplitKVKernel const index_t num_blocks = integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); - return make_tile_window_navigator( - kargs.k_ptr, - kargs.batch_stride_k, + const long_index_t fixed_offset = static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_k, - block_indices, - 
num_blocks, - kargs.page_block_size, - make_tuple(kargs.page_block_size, kargs.hdim_q), - make_tuple(kargs.stride_k, 1)); + kargs.nhead_stride_k; + + return PagedTileWindowNavigator(kargs.k_ptr, + kargs.batch_stride_k, + fixed_offset, + block_indices, + num_blocks, + kargs.page_block_size); } else { - return make_tile_window_navigator( - make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1)); + return SimpleTileWindowNavigator(); } }(); @@ -672,10 +671,21 @@ struct FmhaFwdSplitKVKernel } }(); const auto k_dram = [&]() { + const auto lengths = [&]() { + if constexpr(kIsPagedKV) + { + return make_tuple(kargs.page_block_size, kargs.hdim_q); + } + else + { + return make_tuple(kargs.seqlen_k, kargs.hdim_q); + } + }(); + const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - k_tile_navigator.lengths, - k_tile_navigator.strides, + k_ptr, // will update this pointer if using paged-kvcache + lengths, + make_tuple(kargs.stride_k, 1), number{}, number<1>{});