Simplify TileWindowNavigator interfaces

This commit is contained in:
PoYen, Chen
2024-08-05 16:31:31 +00:00
parent 1c9d77b606
commit ecaaa6f136
2 changed files with 77 additions and 53 deletions

View File

@@ -8,18 +8,17 @@
namespace ck_tile {
template <typename DataType_>
template <typename DataType_, typename TensorViewLengths_, typename TensorViewStrides_>
struct SimpleTileWindowNavigator
{
using DataType = DataType_;
using DataType = DataType_;
using TensorViewLengths = TensorViewLengths_;
using TensorViewStrides = TensorViewStrides_;
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE static constexpr auto
make_tile_window(const TensorView& tensor_view,
const WindowLengths& window_lengths,
const multi_index<TensorView::get_num_of_dimension()>& window_origin)
CK_TILE_DEVICE constexpr SimpleTileWindowNavigator(const TensorViewLengths& lengths_,
const TensorViewStrides& strides_)
: lengths(lengths_), strides(strides_)
{
return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
}
template <typename TensorView, typename WindowLengths>
@@ -30,14 +29,6 @@ struct SimpleTileWindowNavigator
return ck_tile::make_tile_window(tile_window, window_origin);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
return ck_tile::make_tile_window(tile_window, tile_distribution);
}
template <typename TileWindow>
CK_TILE_DEVICE void
move_tile_window(TileWindow& tile_window,
@@ -45,14 +36,30 @@ struct SimpleTileWindowNavigator
{
ck_tile::move_tile_window(tile_window, step);
}
TensorViewLengths lengths;
TensorViewStrides strides;
};
template <typename DataType_, index_t VirtualDim_>
/// Factory helper that builds a SimpleTileWindowNavigator.
///
/// DataType has to be spelled out by the caller; TensorViewLengths and
/// TensorViewStrides appear as the parameter types and can therefore be deduced
/// from the arguments.
///
/// @param lengths  passed through unchanged to the navigator constructor
/// @param strides  passed through unchanged to the navigator constructor
/// @return a SimpleTileWindowNavigator<DataType, TensorViewLengths, TensorViewStrides>
///         constructed from (lengths, strides)
template <typename DataType, typename TensorViewLengths, typename TensorViewStrides>
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(const TensorViewLengths& lengths,
const TensorViewStrides& strides)
{
return SimpleTileWindowNavigator<DataType, TensorViewLengths, TensorViewStrides>(lengths,
strides);
}
template <typename DataType_,
index_t VirtualDim_,
typename TensorViewLengths_,
typename TensorViewStrides_>
struct PagedTileWindowNavigator
{
using DataType = DataType_;
static constexpr index_t VirtualDim = VirtualDim_;
static_assert(VirtualDim == 0 || VirtualDim == 1);
using TensorViewLengths = TensorViewLengths_;
using TensorViewStrides = TensorViewStrides_;
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
long_index_t block_stride_,
@@ -60,46 +67,30 @@ struct PagedTileWindowNavigator
long_index_t row_stride_,
const int32_t* block_indices_,
index_t num_blocks_,
index_t page_block_size_)
index_t page_block_size_,
const TensorViewLengths& lengths_,
const TensorViewStrides& strides_)
: blocks(reinterpret_cast<DataType*>(blocks_)),
block_stride(block_stride_),
head_stride(head_stride_),
row_stride(row_stride_),
block_indices(block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_)
page_block_size(page_block_size_),
lengths(lengths_),
strides(strides_)
{
}
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE auto
make_tile_window(const TensorView& tensor_view,
const WindowLengths& window_lengths,
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
{
auto tile_window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
/// TODO: convert global window origin to local window origin
return tile_window;
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution) const
{
auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution);
/// TODO: convert global window origin to local window origin
return new_tile_window;
}
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
{
auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin);
/// TODO: convert global window origin to local window origin
return new_tile_window;
auto local_window_origin = window_origin;
return ck_tile::make_tile_window(tile_window, local_window_origin);
}
template <typename TileWindow>
@@ -107,10 +98,10 @@ struct PagedTileWindowNavigator
move_tile_window(TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
/// TODO: reset pointer and adjust local window origin
ck_tile::move_tile_window(tile_window, step);
}
private:
DataType* get_block_base(index_t block_index)
{
return blocks + block_index * block_stride + head_stride;
@@ -126,6 +117,35 @@ struct PagedTileWindowNavigator
const int32_t* block_indices;
index_t num_blocks;
index_t page_block_size;
TensorViewLengths lengths;
TensorViewStrides strides;
};
/// Factory helper that builds a PagedTileWindowNavigator.
///
/// DataType and VirtualDim have to be spelled out by the caller (the navigator
/// static_asserts that VirtualDim is 0 or 1); TensorViewLengths and
/// TensorViewStrides appear as the types of the last two parameters and can
/// therefore be deduced from the arguments.
///
/// Every argument is handed to the PagedTileWindowNavigator constructor in the
/// same order it is received; this function performs no validation or
/// conversion of its own.
///
/// @param blocks           untyped pointer that the navigator constructor casts to DataType*
///                         (presumably the base of the paged storage -- confirm against caller)
/// @param block_stride     forwarded as the navigator's block_stride
/// @param head_stride      forwarded as the navigator's head_stride
/// @param row_stride       forwarded as the navigator's row_stride
/// @param block_indices    forwarded as the navigator's block_indices table
/// @param num_blocks       forwarded as the navigator's num_blocks
/// @param page_block_size  forwarded as the navigator's page_block_size
/// @param lengths          forwarded as the navigator's tensor-view lengths
/// @param strides          forwarded as the navigator's tensor-view strides
/// @return a PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>
template <typename DataType,
index_t VirtualDim,
typename TensorViewLengths,
typename TensorViewStrides>
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(copy_const_t<DataType, void>* blocks,
long_index_t block_stride,
long_index_t head_stride,
long_index_t row_stride,
const int32_t* block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorViewLengths& lengths,
const TensorViewStrides& strides)
{
return PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>(
blocks,
block_stride,
head_stride,
row_stride,
block_indices,
num_blocks,
page_block_size,
lengths,
strides);
}
} // namespace ck_tile

View File

@@ -542,17 +542,21 @@ struct FmhaFwdSplitKVKernel
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
kargs.batch_stride_k,
kargs.nhead_stride_k,
kargs.stride_k,
block_indices,
num_blocks,
kargs.page_block_size);
return make_tile_window_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
kargs.nhead_stride_k,
kargs.stride_k,
block_indices,
num_blocks,
kargs.page_block_size,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1));
}
else
{
return SimpleTileWindowNavigator<KDataType>();
return make_tile_window_navigator<KDataType>(
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1));
}
}();
@@ -688,8 +692,8 @@ struct FmhaFwdSplitKVKernel
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
k_tile_navigator.lengths,
k_tile_navigator.strides,
number<FmhaPipeline::kAlignmentK>{},
number<1>{});