Remove ununnecessary data members

This commit is contained in:
PoYen, Chen
2024-08-05 21:52:59 +00:00
parent 3fc7279519
commit bb78353264
2 changed files with 31 additions and 77 deletions

View File

@@ -8,19 +8,11 @@
namespace ck_tile {
template <typename DataType_, typename TensorViewLengths_, typename TensorViewStrides_>
template <typename DataType_>
struct SimpleTileWindowNavigator
{
using DataType = DataType_;
using WindowOrigin = multi_index<2>;
using TensorViewLengths = TensorViewLengths_;
using TensorViewStrides = TensorViewStrides_;
CK_TILE_DEVICE constexpr SimpleTileWindowNavigator(const TensorViewLengths& lengths_,
const TensorViewStrides& strides_)
: lengths(lengths_), strides(strides_)
{
}
using DataType = DataType_;
using WindowOrigin = multi_index<2>;
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE static constexpr auto
@@ -60,31 +52,15 @@ struct SimpleTileWindowNavigator
{
return local_window_origin;
}
TensorViewLengths lengths;
TensorViewStrides strides;
};
template <typename DataType, typename TensorViewLengths, typename TensorViewStrides>
CK_TILE_DEVICE constexpr auto make_tile_window_navigator(const TensorViewLengths& lengths,
const TensorViewStrides& strides)
{
return SimpleTileWindowNavigator<DataType, TensorViewLengths, TensorViewStrides>(lengths,
strides);
}
template <typename DataType_,
index_t VirtualDim_,
typename TensorViewLengths_,
typename TensorViewStrides_>
template <typename DataType_, index_t VirtualDim_>
struct PagedTileWindowNavigator
{
using DataType = DataType_;
static constexpr index_t VirtualDim = VirtualDim_;
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
using TensorViewLengths = TensorViewLengths_;
using TensorViewStrides = TensorViewStrides_;
using WindowOrigin = multi_index<2>;
CK_TILE_DEVICE constexpr PagedTileWindowNavigator(
copy_const_t<DataType, void>* physical_blocks_,
@@ -92,17 +68,13 @@ struct PagedTileWindowNavigator
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorViewLengths& lengths_,
const TensorViewStrides& strides_)
index_t page_block_size_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
lengths(lengths_),
strides(strides_)
page_block_size(page_block_size_)
{
}
@@ -194,34 +166,6 @@ struct PagedTileWindowNavigator
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorViewLengths lengths;
TensorViewStrides strides;
};
template <typename DataType,
index_t VirtualDim,
typename TensorViewLengths,
typename TensorViewStrides>
CK_TILE_DEVICE constexpr auto
make_tile_window_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorViewLengths& lengths,
const TensorViewStrides& strides)
{
return PagedTileWindowNavigator<DataType, VirtualDim, TensorViewLengths, TensorViewStrides>(
physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
lengths,
strides);
}
} // namespace ck_tile

View File

@@ -542,21 +542,20 @@ struct FmhaFwdSplitKVKernel
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
return make_tile_window_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k,
block_indices,
num_blocks,
kargs.page_block_size,
make_tuple(kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.stride_k, 1));
kargs.nhead_stride_k;
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size);
}
else
{
return make_tile_window_navigator<const KDataType>(
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1));
return SimpleTileWindowNavigator<const KDataType>();
}
}();
@@ -672,10 +671,21 @@ struct FmhaFwdSplitKVKernel
}
}();
const auto k_dram = [&]() {
const auto lengths = [&]() {
if constexpr(kIsPagedKV)
{
return make_tuple(kargs.page_block_size, kargs.hdim_q);
}
else
{
return make_tuple(kargs.seqlen_k, kargs.hdim_q);
}
}();
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
k_tile_navigator.lengths,
k_tile_navigator.strides,
k_ptr, // will update this pointer if using paged-kvcache
lengths,
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});