Add tile navigators to the appendkv kernel

This commit is contained in:
PoYen, Chen
2024-08-07 04:51:21 +00:00
parent 443a528adc
commit 7789b53e15
2 changed files with 128 additions and 8 deletions

View File

@@ -139,6 +139,9 @@ struct FmhaFwdAppendKVKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -361,6 +364,58 @@ struct FmhaFwdAppendKVKernel
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
auto k_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k;
return PagedTileWindowNavigator<const KDataType, 0>(kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size);
}
else
{
return SimpleTileWindowNavigator<const KDataType>();
}
}();
auto v_tile_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v;
return PagedTileWindowNavigator<const VDataType, 1>(kargs.v_ptr,
kargs.batch_stride_v,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size);
}
else
{
return SimpleTileWindowNavigator<const VDataType>();
}
}();
// for simplicity, batch stride we just modify the pointer
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
@@ -397,9 +452,20 @@ struct FmhaFwdAppendKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto k_dram = [&]() {
const auto lengths = [&]() {
if constexpr(kIsPagedKV)
{
return make_tuple(kargs.page_block_size, kargs.hdim_q);
}
else
{
return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q);
}
}();
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_q),
lengths,
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
@@ -425,9 +491,20 @@ struct FmhaFwdAppendKVKernel
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto lengths = [&]() {
if constexpr(kIsPagedKV)
{
return make_tuple(kargs.page_block_size, kargs.hdim_v);
}
else
{
return make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v);
}
}();
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.hdim_v),
lengths,
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
@@ -446,9 +523,20 @@ struct FmhaFwdAppendKVKernel
}
else
{
const auto lengths = [&]() {
if constexpr(kIsPagedKV)
{
return make_tuple(kargs.hdim_v, kargs.page_block_size);
}
else
{
return make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew);
}
}();
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k + kargs.seqlen_knew),
lengths,
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
@@ -644,6 +732,8 @@ struct FmhaFwdAppendKVKernel
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.rotary_dim,
k_tile_navigator,
v_tile_navigator,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
}
@@ -658,7 +748,9 @@ struct FmhaFwdAppendKVKernel
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
0,
0, // rotary_dim not used
k_tile_navigator,
v_tile_navigator,
kargs.seqlen_q <= i_m0,
kargs.seqlen_knew <= i_n0);
}

View File

@@ -83,7 +83,9 @@ struct BlockFmhaFwdAppendKVPipeline
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
typename KnewRotarySinDramBlockWindow,
typename KTileWindowNavigator,
typename VTileWindowNavigator>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
@@ -98,6 +100,8 @@ struct BlockFmhaFwdAppendKVPipeline
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
index_t rotary_dim,
const KTileWindowNavigator& k_tile_navigator,
const VTileWindowNavigator& v_tile_navigator,
bool skip_transform_q,
bool skip_append_kv) const
{
@@ -141,7 +145,16 @@ struct BlockFmhaFwdAppendKVPipeline
rotary_dim,
thread_end);
}
store_tile(k_dram_block_window, knew_tile);
if constexpr(kIsPagedKV)
{
/// TODO: handle cross-page-block write
store_tile(k_dram_block_window, knew_tile);
}
else
{
store_tile(k_dram_block_window, knew_tile);
}
// append Vnew to V
auto vnew_window = make_tile_window(
@@ -151,7 +164,16 @@ struct BlockFmhaFwdAppendKVPipeline
auto vnew = load_tile(vnew_window);
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window, vnew_tile);
if constexpr(kIsPagedKV)
{
/// TODO: handle cross-page-block write
store_tile(v_dram_block_window, vnew_tile);
}
else
{
store_tile(v_dram_block_window, vnew_tile);
}
}
if(!skip_transform_q)
@@ -201,7 +223,9 @@ struct BlockFmhaFwdAppendKVPipeline
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
typename KnewRotarySinDramBlockWindow,
typename KTileWindowNavigator,
typename VTileWindowNavigator>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
@@ -213,6 +237,8 @@ struct BlockFmhaFwdAppendKVPipeline
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
index_t rotary_dim,
const KTileWindowNavigator& k_tile_navigator,
const VTileWindowNavigator& v_tile_navigator,
bool skip_transform_q,
bool skip_append_kv) const
{
@@ -229,6 +255,8 @@ struct BlockFmhaFwdAppendKVPipeline
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
rotary_dim,
k_tile_navigator,
v_tile_navigator,
skip_transform_q,
skip_append_kv);
}