mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Introduce 'TileWindowNavigator' types
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_tile_window_navigator.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
||||
|
||||
131
include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp
Normal file
131
include/ck_tile/ops/fmha/block/block_tile_window_navigator.hpp
Normal file
@@ -0,0 +1,131 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_>
|
||||
struct SimpleTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin)
|
||||
{
|
||||
return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin)
|
||||
{
|
||||
return ck_tile::make_tile_window(tile_window, window_origin);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution)
|
||||
{
|
||||
return ck_tile::make_tile_window(tile_window, tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DataType_, index_t VirtualDim_>
|
||||
struct PagedTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
static constexpr index_t VirtualDim = VirtualDim_;
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1);
|
||||
|
||||
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
|
||||
long_index_t block_stride_,
|
||||
long_index_t head_stride_,
|
||||
long_index_t row_stride_,
|
||||
const int32_t* block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_)
|
||||
: blocks(reinterpret_cast<DataType*>(blocks_)),
|
||||
block_stride(block_stride_),
|
||||
head_stride(head_stride_),
|
||||
row_stride(row_stride_),
|
||||
block_indices(block_indices_),
|
||||
num_blocks(num_blocks_),
|
||||
page_block_size(page_block_size_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
|
||||
{
|
||||
auto tile_window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return tile_window;
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution) const
|
||||
{
|
||||
auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return new_tile_window;
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
|
||||
{
|
||||
auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return new_tile_window;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType* get_block_base(index_t block_index)
|
||||
{
|
||||
return blocks + block_index * block_stride + head_stride;
|
||||
}
|
||||
|
||||
DataType* base(index_t i_virtual) { return get_block_base(); }
|
||||
|
||||
DataType* blocks;
|
||||
long_index_t block_stride;
|
||||
long_index_t head_stride;
|
||||
long_index_t row_stride;
|
||||
|
||||
const int32_t* block_indices;
|
||||
index_t num_blocks;
|
||||
index_t page_block_size;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -533,6 +533,29 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
auto k_tile_navigator = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch_ * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.nhead_stride_k,
|
||||
kargs.stride_k,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<KDataType>();
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
@@ -582,22 +605,21 @@ struct FmhaFwdSplitKVKernel
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
|
||||
if(true || kargs.block_table_ptr == nullptr)
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
const auto* block_indices =
|
||||
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch * kargs.batch_stride_block_table;
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
// batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
// batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto* block_table = reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch * kargs.batch_stride_block_table;
|
||||
const auto i_block =
|
||||
static_cast<long_index_t>(block_table[i_n1 / kargs.page_block_size]);
|
||||
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
// batch_offset_k = i_block * kargs.batch_stride_k;
|
||||
batch_offset_v = i_block * kargs.batch_stride_v;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -615,14 +637,27 @@ struct FmhaFwdSplitKVKernel
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const KDataType* k_ptr = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
}
|
||||
}();
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
@@ -899,7 +934,8 @@ struct FmhaFwdSplitKVKernel
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
k_tile_navigator);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -915,7 +951,8 @@ struct FmhaFwdSplitKVKernel
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
k_tile_navigator);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -120,7 +120,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename KTileWindowNavigator>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -142,7 +143,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
BlockDropout& dropout,
|
||||
KTileWindowNavigator& k_tile_navigator) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -239,9 +241,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
@@ -272,15 +272,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc); // initialize C
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
@@ -308,7 +306,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
@@ -557,7 +555,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_tile_navigator.move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
@@ -624,7 +622,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEaccDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename KTileWindowNavigator>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -638,7 +637,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
BlockDropout& dropout) const
|
||||
BlockDropout& dropout,
|
||||
KTileWindowNavigator& k_tile_navigator) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -660,7 +660,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
k_tile_navigator);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user