mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Simplify TileWindowNavigator interfaces
This commit is contained in:
@@ -8,18 +8,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_>
|
||||
template <typename DataType_, typename TensorViewLengths_, typename TensorViewStrides_>
|
||||
struct SimpleTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
using DataType = DataType_;
|
||||
using TensorViewLengths = TensorViewLengths_;
|
||||
using TensorViewStrides = TensorViewStrides_;
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin)
|
||||
CK_TILE_DEVICE constexpr SimpleTileWindowNavigator(const TensorViewLengths& lengths_,
|
||||
const TensorViewStrides& strides_)
|
||||
: lengths(lengths_), strides(strides_)
|
||||
{
|
||||
return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
@@ -30,14 +29,6 @@ struct SimpleTileWindowNavigator
|
||||
return ck_tile::make_tile_window(tile_window, window_origin);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution)
|
||||
{
|
||||
return ck_tile::make_tile_window(tile_window, tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_DEVICE void
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
@@ -45,14 +36,30 @@ struct SimpleTileWindowNavigator
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
}
|
||||
|
||||
TensorViewLengths lengths;
|
||||
TensorViewStrides strides;
|
||||
};
|
||||
|
||||
template <typename DataType_, index_t VirtualDim_>
|
||||
/// Factory helper that builds a SimpleTileWindowNavigator.
///
/// DataType does not appear in the parameter list, so the caller names it
/// explicitly; TensorViewLengths / TensorViewStrides are taken from the
/// argument types.
///
/// \param lengths  passed through unchanged as the navigator's `lengths`
/// \param strides  passed through unchanged as the navigator's `strides`
/// \return SimpleTileWindowNavigator<DataType, TensorViewLengths, TensorViewStrides>
///         constructed from (lengths, strides)
template <typename DataType, typename TensorViewLengths, typename TensorViewStrides>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(const TensorViewLengths& lengths,
|
||||
const TensorViewStrides& strides)
|
||||
{
|
||||
return SimpleTileWindowNavigator<DataType, TensorViewLengths, TensorViewStrides>(lengths,
|
||||
strides);
|
||||
}
|
||||
|
||||
template <typename DataType_,
|
||||
index_t VirtualDim_,
|
||||
typename TensorViewLengths_,
|
||||
typename TensorViewStrides_>
|
||||
struct PagedTileWindowNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
static constexpr index_t VirtualDim = VirtualDim_;
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1);
|
||||
using TensorViewLengths = TensorViewLengths_;
|
||||
using TensorViewStrides = TensorViewStrides_;
|
||||
|
||||
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
|
||||
long_index_t block_stride_,
|
||||
@@ -60,46 +67,30 @@ struct PagedTileWindowNavigator
|
||||
long_index_t row_stride_,
|
||||
const int32_t* block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_)
|
||||
index_t page_block_size_,
|
||||
const TensorViewLengths& lengths_,
|
||||
const TensorViewStrides& strides_)
|
||||
: blocks(reinterpret_cast<DataType*>(blocks_)),
|
||||
block_stride(block_stride_),
|
||||
head_stride(head_stride_),
|
||||
row_stride(row_stride_),
|
||||
block_indices(block_indices_),
|
||||
num_blocks(num_blocks_),
|
||||
page_block_size(page_block_size_)
|
||||
page_block_size(page_block_size_),
|
||||
lengths(lengths_),
|
||||
strides(strides_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const TensorView& tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
|
||||
{
|
||||
auto tile_window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return tile_window;
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution) const
|
||||
{
|
||||
auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return new_tile_window;
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
|
||||
{
|
||||
auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin);
|
||||
/// TODO: convert global window origin to local window origin
|
||||
return new_tile_window;
|
||||
auto local_window_origin = window_origin;
|
||||
|
||||
return ck_tile::make_tile_window(tile_window, local_window_origin);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
@@ -107,10 +98,10 @@ struct PagedTileWindowNavigator
|
||||
move_tile_window(TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
/// TODO: reset pointer and adjust local window origin
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType* get_block_base(index_t block_index)
|
||||
{
|
||||
return blocks + block_index * block_stride + head_stride;
|
||||
@@ -126,6 +117,35 @@ struct PagedTileWindowNavigator
|
||||
const int32_t* block_indices;
|
||||
index_t num_blocks;
|
||||
index_t page_block_size;
|
||||
|
||||
TensorViewLengths lengths;
|
||||
TensorViewStrides strides;
|
||||
};
|
||||
|
||||
/// Factory helper that builds a PagedTileWindowNavigator.
///
/// DataType and VirtualDim are not deducible from the arguments and must be
/// named by the caller; TensorViewLengths / TensorViewStrides come from the
/// types of the last two arguments. All nine arguments are forwarded, in
/// order and unmodified, to the PagedTileWindowNavigator constructor.
///
/// \param blocks           untyped pointer whose const-ness follows DataType
///                         (copy_const_t<DataType, void>*)
/// \param block_stride     forwarded as-is
/// \param head_stride      forwarded as-is
/// \param row_stride       forwarded as-is
/// \param block_indices    forwarded as-is
/// \param num_blocks       forwarded as-is
/// \param page_block_size  forwarded as-is
/// \param lengths          forwarded as-is
/// \param strides          forwarded as-is
/// \return PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>
template <typename DataType,
|
||||
index_t VirtualDim,
|
||||
typename TensorViewLengths,
|
||||
typename TensorViewStrides>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(copy_const_t<DataType, void>* blocks,
|
||||
long_index_t block_stride,
|
||||
long_index_t head_stride,
|
||||
long_index_t row_stride,
|
||||
const int32_t* block_indices,
|
||||
index_t num_blocks,
|
||||
index_t page_block_size,
|
||||
const TensorViewLengths& lengths,
|
||||
const TensorViewStrides& strides)
|
||||
{
|
||||
return PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>(
|
||||
blocks,
|
||||
|
||||
block_stride,
|
||||
head_stride,
|
||||
|
||||
row_stride,
|
||||
|
||||
block_indices,
|
||||
|
||||
num_blocks,
|
||||
|
||||
page_block_size,
|
||||
|
||||
lengths,
|
||||
|
||||
strides);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -542,17 +542,21 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t num_blocks =
|
||||
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.nhead_stride_k,
|
||||
kargs.stride_k,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size);
|
||||
return make_tile_window_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
kargs.batch_stride_k,
|
||||
kargs.nhead_stride_k,
|
||||
kargs.stride_k,
|
||||
block_indices,
|
||||
num_blocks,
|
||||
kargs.page_block_size,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return SimpleTileWindowNavigator<KDataType>();
|
||||
return make_tile_window_navigator<KDataType>(
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -688,8 +692,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen_k, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
k_tile_navigator.lengths,
|
||||
k_tile_navigator.strides,
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user