Introduce 'TileWindowNavigator' types

This commit is contained in:
PoYen, Chen
2024-08-05 15:58:41 +00:00
parent 55b77cf962
commit 1c9d77b606
4 changed files with 201 additions and 31 deletions

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/block_tile_window_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"

View File

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
/// Stateless tile-window navigator: every operation forwards unchanged to the ck_tile free
/// function of the same name.
///
/// The struct has no data members, so all members are `static`. They can be called either
/// through an instance (`nav.make_tile_window(...)`), including a const one, or through the
/// type itself.
///
/// @tparam DataType_ element type, re-exported as `DataType`
template <typename DataType_>
struct SimpleTileWindowNavigator
{
    using DataType = DataType_;

    /// Build a tile window from a tensor view, window lengths and a window origin.
    /// @return whatever ck_tile::make_tile_window returns for the same arguments
    template <typename TensorView, typename WindowLengths>
    CK_TILE_DEVICE static constexpr auto
    make_tile_window(const TensorView& tensor_view,
                     const WindowLengths& window_lengths,
                     const multi_index<TensorView::get_num_of_dimension()>& window_origin)
    {
        // Explicit qualification is required: inside this struct the unqualified name
        // would resolve to the member overload set, not to the namespace-scope function.
        return ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
    }

    /// Build a tile window from an existing window and a new window origin.
    /// @return whatever ck_tile::make_tile_window returns for the same arguments
    template <typename TensorView, typename WindowLengths>
    CK_TILE_DEVICE static constexpr auto
    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                     const multi_index<TensorView::get_num_of_dimension()>& window_origin)
    {
        return ck_tile::make_tile_window(tile_window, window_origin);
    }

    /// Build a tile window from an existing window and a tile distribution.
    /// @return whatever ck_tile::make_tile_window returns for the same arguments
    template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
    CK_TILE_DEVICE static constexpr auto
    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                     const StaticTileDistribution& tile_distribution)
    {
        return ck_tile::make_tile_window(tile_window, tile_distribution);
    }

    /// Move `tile_window` by `step`, forwarding to ck_tile::move_tile_window.
    template <typename TileWindow>
    CK_TILE_DEVICE static void
    move_tile_window(TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
    {
        ck_tile::move_tile_window(tile_window, step);
    }
};
/// Tile-window navigator that carries paging state: a block pool pointer, its strides, and a
/// table of block indices.
///
/// In this revision every public method forwards to the matching ck_tile free function.
/// Converting a global window origin into a block-local one is still TODO, so the stored
/// paging state is not consulted by the public interface yet.
///
/// @tparam DataType_   element type of the block pool (may be const-qualified)
/// @tparam VirtualDim_ must be 0 or 1 (enforced by static_assert); presumably the index of the
///                     paged tensor dimension -- TODO confirm
template <typename DataType_, index_t VirtualDim_>
struct PagedTileWindowNavigator
{
    using DataType                      = DataType_;
    static constexpr index_t VirtualDim = VirtualDim_;
    static_assert(VirtualDim == 0 || VirtualDim == 1);

    /// Store the paging description. No pointer is dereferenced here.
    ///
    /// @param blocks_          untyped pointer to the block pool; const-ness follows DataType
    /// @param block_stride_    element offset multiplied by a block index in get_block_base()
    /// @param head_stride_     element offset added once in get_block_base()
    /// @param row_stride_      stored only; not used by any member in this revision
    /// @param block_indices_   table indexed by (virtual index / page_block_size_) in base()
    /// @param num_blocks_      stored only; not used by any member in this revision
    /// @param page_block_size_ divisor applied to a virtual index in base()
    CK_TILE_DEVICE constexpr PagedTileWindowNavigator(copy_const_t<DataType, void>* blocks_,
                                                      long_index_t block_stride_,
                                                      long_index_t head_stride_,
                                                      long_index_t row_stride_,
                                                      const int32_t* block_indices_,
                                                      index_t num_blocks_,
                                                      index_t page_block_size_)
        // void* -> T* is a static_cast conversion; reinterpret_cast is not needed.
        : blocks(static_cast<DataType*>(blocks_)),
          block_stride(block_stride_),
          head_stride(head_stride_),
          row_stride(row_stride_),
          block_indices(block_indices_),
          num_blocks(num_blocks_),
          page_block_size(page_block_size_)
    {
    }

    /// Build a tile window from a tensor view, window lengths and a window origin.
    /// @return the window produced by ck_tile::make_tile_window, unmodified
    template <typename TensorView, typename WindowLengths>
    CK_TILE_DEVICE auto
    make_tile_window(const TensorView& tensor_view,
                     const WindowLengths& window_lengths,
                     const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
    {
        // Qualified call: the unqualified name would find this member overload set.
        auto tile_window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
        /// TODO: convert global window origin to local window origin
        return tile_window;
    }

    /// Build a tile window from an existing window and a tile distribution.
    /// @return the window produced by ck_tile::make_tile_window, unmodified
    template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
    CK_TILE_DEVICE auto
    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                     const StaticTileDistribution& tile_distribution) const
    {
        auto new_tile_window = ck_tile::make_tile_window(tile_window, tile_distribution);
        /// TODO: convert global window origin to local window origin
        return new_tile_window;
    }

    /// Build a tile window from an existing window and a new window origin.
    /// @return the window produced by ck_tile::make_tile_window, unmodified
    template <typename TensorView, typename WindowLengths>
    CK_TILE_DEVICE auto
    make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                     const multi_index<TensorView::get_num_of_dimension()>& window_origin) const
    {
        auto new_tile_window = ck_tile::make_tile_window(tile_window, window_origin);
        /// TODO: convert global window origin to local window origin
        return new_tile_window;
    }

    /// Move `tile_window` by `step`, forwarding to ck_tile::move_tile_window.
    /// No page-boundary handling is performed in this revision.
    template <typename TileWindow>
    CK_TILE_DEVICE void
    move_tile_window(TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
    {
        ck_tile::move_tile_window(tile_window, step);
    }

    private:
    /// Pointer computed as `blocks + block_index * block_stride + head_stride`.
    /// Const so that the const public methods can call it.
    /// NOTE(review): head_stride is added without being scaled by a head index; if the
    /// constructor receives a per-head stride rather than a precomputed head offset this
    /// only addresses head 1 -- confirm against the caller.
    CK_TILE_DEVICE DataType* get_block_base(index_t block_index) const
    {
        return blocks + block_index * block_stride + head_stride;
    }

    /// Block base pointer for the page containing virtual index `i_virtual`.
    /// The page number `i_virtual / page_block_size` is looked up in `block_indices` and the
    /// resulting entry is passed to get_block_base().
    /// NOTE(review): the previous body called get_block_base() with no argument and ignored
    /// i_virtual; the table lookup used here is the presumed intent -- confirm.
    CK_TILE_DEVICE DataType* base(index_t i_virtual) const
    {
        return get_block_base(block_indices[i_virtual / page_block_size]);
    }

    DataType* blocks;             // typed pointer to the block pool
    long_index_t block_stride;    // multiplied by a block index in get_block_base()
    long_index_t head_stride;     // added once in get_block_base()
    long_index_t row_stride;      // stored; unused in this revision
    const int32_t* block_indices; // page number -> block index table, read by base()
    index_t num_blocks;           // stored; unused in this revision
    index_t page_block_size;      // divisor mapping a virtual index to a page number
};
} // namespace ck_tile

View File

@@ -533,6 +533,29 @@ struct FmhaFwdSplitKVKernel
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
auto k_tile_navigator = [&, i_batch_ = i_batch]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
kargs.batch_stride_k,
kargs.nhead_stride_k,
kargs.stride_k,
block_indices,
num_blocks,
kargs.page_block_size);
}
else
{
return SimpleTileWindowNavigator<KDataType>();
}
}();
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
@@ -582,22 +605,21 @@ struct FmhaFwdSplitKVKernel
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
if(true || kargs.block_table_ptr == nullptr)
if constexpr(kIsPagedKV)
{
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
// batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
// batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
else
{
const auto* block_table = reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch * kargs.batch_stride_block_table;
const auto i_block =
static_cast<long_index_t>(block_table[i_n1 / kargs.page_block_size]);
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
// batch_offset_k = i_block * kargs.batch_stride_k;
batch_offset_v = i_block * kargs.batch_stride_v;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -615,14 +637,27 @@ struct FmhaFwdSplitKVKernel
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const KDataType* k_ptr = [&, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k +
batch_offset_k;
}
else
{
return reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k +
batch_offset_k;
}
}();
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
@@ -899,7 +934,8 @@ struct FmhaFwdSplitKVKernel
position_encoding,
kargs.scale_s,
smem_ptr,
dropout);
dropout,
k_tile_navigator);
}
else
{
@@ -915,7 +951,8 @@ struct FmhaFwdSplitKVKernel
position_encoding,
kargs.scale_s,
smem_ptr,
dropout);
dropout,
k_tile_navigator);
}
}();

View File

@@ -120,7 +120,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
typename PositionEncoding,
typename KTileWindowNavigator>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -142,7 +143,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
BlockDropout& dropout,
KTileWindowNavigator& k_tile_navigator) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -239,9 +241,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
k_tile_navigator.make_tile_window(k_dram_block_window_tmp, {seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
@@ -272,15 +272,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{
// STAGE 1, QK gemm
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto k_block_tile = load_tile(k_dram_window);
{
move_tile_window(k_dram_window, {0, kK0});
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc); // initialize C
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
@@ -308,7 +306,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
k_tile_navigator.move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
@@ -557,7 +555,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_tile_navigator.move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{
block_sync_lds();
@@ -624,7 +622,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
typename PositionEncoding,
typename KTileWindowNavigator>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
@@ -638,7 +637,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
BlockDropout& dropout,
KTileWindowNavigator& k_tile_navigator) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -660,7 +660,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
position_encoding,
scale_s,
smem_ptr,
dropout);
dropout,
k_tile_navigator);
}
};