diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp
index 17aa04f53d..cd26c09206 100644
--- a/include/ck_tile/ops/fmha.hpp
+++ b/include/ck_tile/ops/fmha.hpp
@@ -8,6 +8,7 @@
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
 #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+#include "ck_tile/ops/fmha/block/block_tile_window_navigator.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
diff --git a/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp
new file mode 100644
index 0000000000..ed39baa0e1
--- /dev/null
+++ b/include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp
@@ -0,0 +1,166 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/core/tensor/tile_window.hpp"
+
+namespace ck_tile {
+
+/// Tile-window navigator that forwards every request unchanged to the free
+/// ck_tile::make_tile_window / ck_tile::move_tile_window functions.
+///
+/// The struct holds no data members, so every operation is a static member and
+/// can be invoked through a const or a non-const object alike.
+template <typename DataType_>
+struct SimpleTileWindowNavigator
+{
+    using DataType = DataType_;
+
+    /// Build a tile window over tensor_view with the given lengths and origin.
+    template <typename TensorView, typename WindowLengths>
+    CK_TILE_DEVICE static constexpr auto
+    make_tile_window(const TensorView& tensor_view,
+                     const WindowLengths& window_lengths,
+                     const multi_index<TensorView::get_num_of_dimension()>& window_origin)
+    {
+        return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
+    }
+
+    /// Build a tile window from an existing one, placed at window_origin.
+    template <typename TensorView, typename WindowLengths>
+    CK_TILE_DEVICE static constexpr auto
+    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                     const multi_index<TensorView::get_num_of_dimension()>& window_origin)
+    {
+        return ck_tile::make_tile_window(tile_window, window_origin);
+    }
+
+    /// Build a tile window from an existing one, using tile_distribution.
+    template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+    CK_TILE_DEVICE static constexpr auto
+    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                     const StaticTileDistribution& tile_distribution)
+    {
+        return ck_tile::make_tile_window(tile_window, tile_distribution);
+    }
+
+    /// Advance tile_window by step.
+    template <typename TileWindow>
+    CK_TILE_DEVICE static void
+    move_tile_window(TileWindow& tile_window,
+                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
+    {
+        ck_tile::move_tile_window(tile_window, step);
+    }
+};
+
+/// Tile-window navigator that additionally stores a block pointer, strides and
+/// a table of block indices describing block-wise (paged) storage.
+///
+/// NOTE(review): the window members below still forward straight to the free
+/// ck_tile functions; the TODOs mark where the origin conversion is missing.
+template <typename DataType_, index_t VirtualDim_>
+struct PagedTileWindowNavigator
+{
+    using DataType                      = DataType_;
+    static constexpr index_t VirtualDim = VirtualDim_;
+    static_assert(VirtualDim == 0 || VirtualDim == 1);
+
+    /// Stores every argument as given; blocks_ is cast to DataType*.
+    ///
+    /// @param blocks_         pointer the block addresses are computed from
+    /// @param block_stride_   element offset applied per block index
+    /// @param head_stride_    element offset added to every block address
+    /// @param row_stride_     stored only; not read by any member yet
+    /// @param block_indices_  table indexed by (virtual index / page_block_size_)
+    /// @param num_blocks_     stored only; presumably the length of block_indices_
+    /// @param page_block_size_ divisor used to map a virtual index to a table slot
+    CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
+                                                      long_index_t block_stride_,
+                                                      long_index_t head_stride_,
+                                                      long_index_t row_stride_,
+                                                      const int32_t* block_indices_,
+                                                      index_t num_blocks_,
+                                                      index_t page_block_size_)
+        : blocks(reinterpret_cast<DataType*>(blocks_)),
+          block_stride(block_stride_),
+          head_stride(head_stride_),
+          row_stride(row_stride_),
+          block_indices(block_indices_),
+          num_blocks(num_blocks_),
+          page_block_size(page_block_size_)
+    {
+    }
+
+    /// Build a tile window over tensor_view with the given lengths and origin.
+    template <typename TensorView, typename WindowLengths>
+    CK_TILE_DEVICE auto
+    make_tile_window(const TensorView& tensor_view,
+                     const WindowLengths& window_lengths,
+                     const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
+    {
+        auto tile_window =
+            ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
+        /// TODO: convert global window origin to local window origin
+        return tile_window;
+    }
+
+    /// Build a tile window from an existing one, using tile_distribution.
+    template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+    CK_TILE_DEVICE auto
+    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                     const StaticTileDistribution& tile_distribution) const
+    {
+        auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution);
+        /// TODO: convert global window origin to local window origin
+        return new_tile_window;
+    }
+
+    /// Build a tile window from an existing one, placed at window_origin.
+    template <typename TensorView, typename WindowLengths>
+    CK_TILE_DEVICE auto
+    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                     const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
+    {
+        auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin);
+        /// TODO: convert global window origin to local window origin
+        return new_tile_window;
+    }
+
+    /// Advance tile_window by step (no page translation is applied).
+    template <typename TileWindow>
+    CK_TILE_DEVICE void
+    move_tile_window(TileWindow& tile_window,
+                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
+    {
+        ck_tile::move_tile_window(tile_window, step);
+    }
+
+    private:
+    /// Address of block block_index: blocks + block_index * block_stride + head_stride.
+    /// NOTE(review): head_stride is added once with no head index -- confirm.
+    CK_TILE_DEVICE DataType* get_block_base(index_t block_index) const
+    {
+        return blocks + block_index * block_stride + head_stride;
+    }
+
+    /// Address of the block that holds virtual index i_virtual, found through
+    /// block_indices[i_virtual / page_block_size].
+    CK_TILE_DEVICE DataType* base(index_t i_virtual) const
+    {
+        return get_block_base(block_indices[i_virtual / page_block_size]);
+    }
+
+    DataType* blocks;
+    long_index_t block_stride;
+    long_index_t head_stride;
+    long_index_t row_stride;
+
+    const int32_t* block_indices;
+    index_t num_blocks;
+    index_t page_block_size;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index fdb4407691..ca54a496d6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -533,6 +533,29 @@ struct FmhaFwdSplitKVKernel const long_index_t batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; + auto k_tile_navigator = [&, i_batch_ = i_batch]() { + if constexpr(kIsPagedKV) + { + const auto* block_indices
= + reinterpret_cast(kargs.block_table_ptr) + + i_batch_ * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + return PagedTileWindowNavigator(kargs.k_ptr, + kargs.batch_stride_k, + kargs.nhead_stride_k, + kargs.stride_k, + block_indices, + num_blocks, + kargs.page_block_size); + } + else + { + return SimpleTileWindowNavigator(); + } + }(); + if constexpr(kIsGroupMode) { // get starting offset for each batch @@ -582,22 +605,21 @@ struct FmhaFwdSplitKVKernel else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - - if(true || kargs.block_table_ptr == nullptr) + if constexpr(kIsPagedKV) { - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + const auto* block_indices = + reinterpret_cast(kargs.block_table_ptr) + + i_batch * kargs.batch_stride_block_table; + const index_t num_blocks = + integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size); + + // batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + // batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; } else { - const auto* block_table = reinterpret_cast(kargs.block_table_ptr) + - i_batch * kargs.batch_stride_block_table; - const auto i_block = - static_cast(block_table[i_n1 / kargs.page_block_size]); - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - // batch_offset_k = i_block * kargs.batch_stride_k; - batch_offset_v = i_block * kargs.batch_stride_v; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -615,14 +637,27 @@ struct FmhaFwdSplitKVKernel const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q; - const KDataType* k_ptr = - reinterpret_cast(kargs.k_ptr) + - static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - batch_offset_k; + const 
KDataType* k_ptr = [&, i_nhead_ = i_nhead]() { + if constexpr(kIsPagedKV) + { + return reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k + + batch_offset_k; + } + else + { + return reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k + + batch_offset_k; + } + }(); const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; + OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; @@ -899,7 +934,8 @@ struct FmhaFwdSplitKVKernel position_encoding, kargs.scale_s, smem_ptr, - dropout); + dropout, + k_tile_navigator); } else { @@ -915,7 +951,8 @@ struct FmhaFwdSplitKVKernel position_encoding, kargs.scale_s, smem_ptr, - dropout); + dropout, + k_tile_navigator); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ed14fa8bbe..5b54915237 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -120,7 +120,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename KTileWindowNavigator> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -142,7 +143,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + BlockDropout& dropout, + KTileWindowNavigator& k_tile_navigator) const { static_assert( 
std::is_same_v> && @@ -239,9 +241,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( @@ -272,15 +272,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { // STAGE 1, QK gemm auto k_dram_window = make_tile_window( - k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + k_dram_block_window, Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load auto k_block_tile = load_tile(k_dram_window); { - move_tile_window(k_dram_window, {0, kK0}); + k_tile_navigator.move_tile_window(k_dram_window, {0, kK0}); clear_tile(s_acc); // initialize C store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); @@ -308,7 +306,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS sequence{}), k_lds_window); block_sync_lds(); - move_tile_window(k_dram_window, {0, kK0}); + k_tile_navigator.move_tile_window(k_dram_window, {0, kK0}); store_tile( k_lds_window, @@ -557,7 +555,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS }); } // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); + k_tile_navigator.move_tile_window(k_dram_block_window, {kN0, 0}); // tail { block_sync_lds(); @@ -624,7 +622,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename KTileWindowNavigator> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& 
k_dram_block_window_tmp, // N0*K0 tile @@ -638,7 +637,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + BlockDropout& dropout, + KTileWindowNavigator& k_tile_navigator) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -660,7 +660,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS position_encoding, scale_s, smem_ptr, - dropout); + dropout, + k_tile_navigator); } };